diff --git a/rtp_llm/cpp/engine_base/stream/StreamGroups.h b/rtp_llm/cpp/engine_base/stream/StreamGroups.h index 128beed36b..f61b8bc185 100644 --- a/rtp_llm/cpp/engine_base/stream/StreamGroups.h +++ b/rtp_llm/cpp/engine_base/stream/StreamGroups.h @@ -36,7 +36,12 @@ struct StreamGroups { has_multimodal_input_ = true; } } - total_block_update_copy_num_ += stream->streamCacheResource().getKVBlockUpdateMapping().size(); + auto block_update_copy_num = stream->streamCacheResource().getKVBlockUpdateMapping().size(); + if (stream->isContextStream()) { + context_block_update_copy_num_ += block_update_copy_num; + } else { + decode_block_update_copy_num_ += block_update_copy_num; + } model_execute_token_size_ += stream->currentExecuteTokenSize(); total_sampler_batch_size_in_ += stream->needTilingForSampling() ? next_batch_size : cur_batch_size; total_sampler_batch_size_out_ += next_batch_size; @@ -64,7 +69,13 @@ struct StreamGroups { return total_sampler_batch_size_out_; } size_t totalBlockUpdateCopyNum() const { - return total_block_update_copy_num_; + return decode_block_update_copy_num_ + context_block_update_copy_num_; + } + size_t decodeBlockUpdateCopyNum() const { + return decode_block_update_copy_num_; + } + size_t contextBlockUpdateCopyNum() const { + return context_block_update_copy_num_; } size_t curBlocksNum() const { return max_blocks_num_; @@ -173,7 +184,7 @@ struct StreamGroups { << ", total_model_batch_size: " << totalModelBatchSize() << ", total_sampler_batch_size_in: " << total_sampler_batch_size_in_ << ", total_sampler_batch_size_out: " << total_sampler_batch_size_out_ - << ", total_block_update_copy_num: " << total_block_update_copy_num_ + << ", total_block_update_copy_num: " << totalBlockUpdateCopyNum() << ", max_blocks_num_: " << max_blocks_num_ << ", model_execute_token_size: " << model_execute_token_size_ << ", max_seq_len: " << max_seq_len_ << ", is_fake_stream: " << is_fake_stream_ << "}"; @@ -195,22 +206,23 @@ struct StreamGroups { private: std::list context_streams_; std::list decode_streams_; - size_t total_sampler_batch_size_in_ = 0; - size_t total_sampler_batch_size_out_ = 0; - size_t total_decode_batch_size_ = 0; - size_t total_context_batch_size_ = 0; - size_t total_block_update_copy_num_ = 0; - size_t max_blocks_num_ = 0; - size_t model_execute_token_size_ = 0; - size_t max_seq_len_ = 0; - size_t max_context_seq_len_ = 0; - size_t max_reuse_length_ = 0; - size_t cum_context_seq_len_ = 0; - size_t multimodal_features_len_ = 0; - size_t total_score_batch_size_ = 0; - bool has_multimodal_input_ = false; - bool gen_timeline_ = false; - bool is_fake_stream_ = false; + size_t total_sampler_batch_size_in_ = 0; + size_t total_sampler_batch_size_out_ = 0; + size_t total_decode_batch_size_ = 0; + size_t total_context_batch_size_ = 0; + size_t decode_block_update_copy_num_ = 0; + size_t context_block_update_copy_num_ = 0; + size_t max_blocks_num_ = 0; + size_t model_execute_token_size_ = 0; + size_t max_seq_len_ = 0; + size_t max_context_seq_len_ = 0; + size_t max_reuse_length_ = 0; + size_t cum_context_seq_len_ = 0; + size_t multimodal_features_len_ = 0; + size_t total_score_batch_size_ = 0; + bool has_multimodal_input_ = false; + bool gen_timeline_ = false; + bool is_fake_stream_ = false; std::list adapter_names; }; diff --git a/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.cc b/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.cc index ff34bf0ebb..f628d2910d 100644 --- a/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.cc +++ b/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.cc @@ -1,651 +1,72 @@ -#include -#include -#include -#include -#include -#include -#include "c10/core/DeviceType.h" -#include "c10/core/ScalarType.h" -#include "rtp_llm/cpp/models/Sampler.h" -#include "rtp_llm/cpp/utils/AssertUtils.h" -#include "rtp_llm/cpp/core/Types.h" -#include "rtp_llm/cpp/cache/Types.h" #include "rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.h" -#include "rtp_llm/cpp/models/logits_processor/LogitsProcessorStates.h" -#include "rtp_llm/cpp/models/SampleInfos.h" -#include "rtp_llm/cpp/utils/TensorDebugUtils.h" -#if USING_CUDA -#include "rtp_llm/cpp/cuda/ops/StandaloneOps.h" -#include "ATen/cuda/CUDAContext.h" -#endif - -using namespace std; namespace rtp_llm { -absl::StatusOr NormalBatchStreamProcessor::gatherModelInput(const StreamGroups& stream_groups) const { - RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); - auto context_streams = stream_groups.contextStreams(); - auto decode_streams = stream_groups.decodeStreams(); - RTP_LLM_LOG_DEBUG( - "context_streams size = %d, decode_streams size = %d", context_streams.size(), decode_streams.size()); - GptModelInputs model_input; - const size_t current_tokens_size = stream_groups.modelExecuteTokenSize(); - const size_t total_batch_size = stream_groups.totalModelBatchSize(); - const size_t total_decode_batch_size = stream_groups.totalDecodeBatchSize(); - const size_t total_context_batch_size = stream_groups.totalContextBatchSize(); - const size_t total_block_copy_num = stream_groups.totalBlockUpdateCopyNum(); - const size_t max_blocks_num = stream_groups.curBlocksNum(); - const size_t multimodal_features_len = stream_groups.mmFeaturesLen(); - - const bool has_multimodal_input = is_multimodal_ && stream_groups.has_multimodal_input(); - const bool need_cal_position_id = (mm_position_ids_style_ != PositionIdsStyle::DEFAULT) || has_positional_encoding_; - - size_t num_layers = 0; - if (model_input.kv_cache_layer_to_group.defined()) { - num_layers = model_input.kv_cache_layer_to_group.numel(); - } else { - num_layers = layer_to_kv_cache_group_id_.size(); - } - - // Use pinned_memory(true) in TensorOptions to leverage PyTorch's CachingHostAllocator, - // which reuses pinned memory blocks across calls instead of cudaHostAlloc/Free each time. - static const auto pinned_i32 = torch::TensorOptions(torch::kInt32).pinned_memory(true); - static const auto pinned_i64 = torch::TensorOptions(torch::kInt64).pinned_memory(true); - static const auto pinned_bool = torch::TensorOptions(torch::kBool).pinned_memory(true); - - model_input.combo_tokens = torch::empty({(int64_t)current_tokens_size}, pinned_i32); - if (max_blocks_num) { - model_input.kv_cache_kernel_block_id = torch::zeros({(int64_t)kv_cache_group_nums_, - (int64_t)total_batch_size, - (int64_t)(max_blocks_num * kernel_blocks_per_kv_block_)}, - pinned_i32); - model_input.kv_cache_block_id = torch::zeros( - {(int64_t)kv_cache_group_nums_, (int64_t)total_batch_size, (int64_t)max_blocks_num}, pinned_i32); - model_input.kv_cache_layer_to_group = torch::empty({(int64_t)num_layers_}, pinned_i32); - model_input.kv_cache_group_types = torch::empty({(int64_t)kv_cache_group_nums_}, pinned_i32); - model_input.kv_cache_update_mapping = torch::empty({(int64_t)total_block_copy_num, 2}, pinned_i32); - model_input.cache_keys = torch::empty({(int64_t)total_context_batch_size, (int64_t)max_blocks_num}, pinned_i64); - } - model_input.request_id = torch::empty({(int64_t)total_context_batch_size}, pinned_i64); - model_input.request_pd_separation = torch::empty({(int64_t)total_context_batch_size}, pinned_bool); - model_input.input_lengths = torch::empty({(int64_t)total_batch_size}, pinned_i32); - model_input.sequence_lengths = torch::empty({(int64_t)total_decode_batch_size}, pinned_i32); - model_input.lm_output_indexes = torch::empty({(int64_t)total_batch_size}, pinned_i32); - model_input.lm_output_lengths = torch::empty({(int64_t)total_batch_size}, pinned_i32); - model_input.prefix_lengths = torch::empty({(int64_t)total_context_batch_size}, pinned_i32); - if (need_cal_position_id) { - model_input.combo_position_ids = - torch::empty({(int64_t)(current_tokens_size * position_id_len_factor_)}, pinned_i32); - } - if (has_multimodal_input) { - model_input.text_tokens_mask = torch::empty({(int64_t)current_tokens_size}, pinned_i32); - model_input.mm_features_locs = torch::empty({(int64_t)multimodal_features_len}, pinned_i32); - } - model_input.kv_block_stride_bytes = block_stride_bytes_; - model_input.kv_scale_stride_bytes = scale_stride_bytes_; - model_input.seq_size_per_block = seq_size_per_block_; - model_input.kernel_seq_size_per_block = kernel_seq_size_per_block_; - model_input.pd_separation = role_type_ == RoleType::PREFILL; - model_input.warmup = warm_up_; - model_input.decode_entrance = decode_entrance_; - model_input.is_fake_stream = stream_groups.isFakeStream(); - - int* merged_tokens = model_input.combo_tokens.data_ptr(); - int* input_lengths = model_input.input_lengths.data_ptr(); - int* sequence_lengths = model_input.sequence_lengths.data_ptr(); - int* lm_output_indexes = model_input.lm_output_indexes.data_ptr(); - int* lm_output_lengths = model_input.lm_output_lengths.data_ptr(); - int* prefix_lengths = model_input.prefix_lengths.data_ptr(); - int* combo_position_ids = need_cal_position_id ? model_input.combo_position_ids.data_ptr() : nullptr; - int* merged_text_mask = has_multimodal_input ? model_input.text_tokens_mask.data_ptr() : nullptr; - int* mm_features_locs = has_multimodal_input ? model_input.mm_features_locs.data_ptr() : nullptr; - int batch_idx = 0; - int input_vocab_size = input_vocab_size_ ? input_vocab_size_ : vocab_size_; - - if (model_input.kv_cache_layer_to_group.defined()) { - std::memcpy(model_input.kv_cache_layer_to_group.data_ptr(), - layer_to_kv_cache_group_id_.data(), - static_cast(num_layers) * sizeof(int32_t)); - } - - if (model_input.kv_cache_group_types.defined()) { - auto* dst = model_input.kv_cache_group_types.data_ptr(); - for (size_t g = 0; g < kv_cache_group_nums_; ++g) { - dst[g] = static_cast(kv_cache_group_types_[g]); - } - } - - auto* kv_cache_update_mapping = model_input.kv_cache_update_mapping.defined() ? - (BlockIdPair*)model_input.kv_cache_update_mapping.data_ptr() : - nullptr; - const auto add_cache_update_copy = [&](const auto& update_mapping) { - size_t update_copy_num = update_mapping.size(); - std::memcpy(kv_cache_update_mapping, update_mapping.data(), update_copy_num * sizeof(BlockIdPair)); - kv_cache_update_mapping += update_copy_num; - }; - - if (merged_text_mask) { - std::fill(merged_text_mask, merged_text_mask + current_tokens_size, 1); - } - - for (const auto& stream : decode_streams) { - model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss(); - auto current_batch_size = stream->currentBatchSize(); - - auto& kv_cache = *stream->kvCachePtr(); - RTP_LLM_LOG_DEBUG("decode kv_cache: %s", kv_cache.debugString().c_str()); - RTP_LLM_LOG_DEBUG("decode stream: %s", stream->debugString().c_str()); - - for (auto i = 0; i < current_batch_size; ++i) { - model_input.trace_ids.push_back(stream->traceId()); - - auto currentTokens = stream->currentExecuteTokens(i); - if (currentTokens[0] >= input_vocab_size) { - std::ostringstream error_msg; - error_msg << "stream [" << stream->streamId() << "] token_id " << currentTokens[0] - << " exceed vocab_size " << input_vocab_size; - return absl::InvalidArgumentError(error_msg.str()); - } - merged_tokens[batch_idx] = currentTokens[0]; - input_lengths[batch_idx] = stream->inputLength(); - sequence_lengths[batch_idx] = stream->seqLength() - 1; // need remove - if (need_cal_position_id) { - stream->generateNextPositionId(combo_position_ids + batch_idx * position_id_len_factor_); - } - lm_output_indexes[batch_idx] = batch_idx; - lm_output_lengths[batch_idx] = 1; - if (max_blocks_num) { - RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_kernel_block_id.dim() == 3, - "hybrid kv_cache_kernel_block_id must be 3-D"); - RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_block_id.dim() == 3, - "hybrid kv_cache_block_id must be 3-D"); - const size_t batch = model_input.kv_cache_kernel_block_id.size(1); - int32_t* kernel_dst_base = model_input.kv_cache_kernel_block_id.data_ptr(); - int32_t* store_dst_base = model_input.kv_cache_block_id.data_ptr(); - for (int gid = 0; gid < kv_cache.groupNums(); ++gid) { - auto& kernel_blocks = kv_cache.kernelBlocks(i, gid); - int32_t* kernel_dst = kernel_dst_base - + (static_cast(gid) * batch + static_cast(batch_idx)) - * max_blocks_num * kernel_blocks_per_kv_block_; - std::memcpy(kernel_dst, kernel_blocks.data(), kernel_blocks.size() * sizeof(int32_t)); - - auto& physical_blocks = kv_cache.blocks(i, gid); - int32_t* store_dst = - store_dst_base - + (static_cast(gid) * batch + static_cast(batch_idx)) * max_blocks_num; - std::memcpy(store_dst, physical_blocks.data(), physical_blocks.size() * sizeof(int32_t)); - } - } - batch_idx += 1; - } - - if (max_blocks_num) { - add_cache_update_copy(stream->streamCacheResource().getKVBlockUpdateMapping()); - } - - stream->step(); - } - - std::vector gathered_mm_features; - int token_idx = batch_idx; - int cum_output_seq_len = batch_idx; - int mm_feature_index = 0; - - for (const auto& stream : context_streams) { - // context stream也需要batch运行是为了perf test的场景 - model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss(); - auto current_batch_size = stream->currentBatchSize(); - - auto& kv_cache = *stream->kvCachePtr(); - if (enable_detail_log_) { - RTP_LLM_LOG_DEBUG("context kv_cache: %s", kv_cache.debugString().c_str()); - RTP_LLM_LOG_DEBUG("context stream: %s", stream->debugString().c_str()); - } else { - RTP_LLM_LOG_TRACE("context kv_cache: %s", kv_cache.debugString().c_str()); - RTP_LLM_LOG_TRACE("context stream: %s", stream->debugString().c_str()); - } - - // TODO(xinfei.sxf) deal with adjusted common seq len. - for (auto i = 0; i < current_batch_size; ++i) { - model_input.trace_ids.push_back(stream->traceId()); - - auto input_tokens = stream->currentExecuteTokens(i); - auto input_masks = stream->textTokensMask(); - memcpy(merged_tokens + token_idx, input_tokens.data(), input_tokens.size() * sizeof(int)); - cum_output_seq_len += input_tokens.size(); - - for (int index = 0; index < input_tokens.size(); ++index) { - if (input_tokens[index] >= input_vocab_size && (index >= input_masks.size() || input_masks[index])) { - std::ostringstream error_msg; - error_msg << "stream [" << stream->streamId() << "] token_id " << input_tokens[index] - << " exceed vocab_size " << input_vocab_size; - return absl::InvalidArgumentError(error_msg.str()); - } - } - - input_lengths[batch_idx] = input_tokens.size(); - prefix_lengths[batch_idx - total_decode_batch_size] = stream->prefixLength(); - lm_output_indexes[batch_idx] = cum_output_seq_len - 1; - lm_output_lengths[batch_idx] = 1; - - if (has_multimodal_input) { - std::vector mm_features = stream->multimodalFeatures(); - torch::Tensor mm_locs = stream->multimodalLocations(); - if (mm_locs.defined()) { - auto* mm_locs_data = mm_locs.data_ptr(); - for (int i = 0; i < mm_locs.numel(); ++i) { - mm_features_locs[mm_feature_index] = mm_locs_data[i] + token_idx - stream->reuseLength(); - mm_feature_index++; - } - for (auto& mm_feature : mm_features) { - if (!mm_feature.is_cuda()) { - gathered_mm_features.emplace_back(mm_feature.to(torch::kCUDA)); - } else { - gathered_mm_features.emplace_back(mm_feature); - } - } - auto text_token_mask = stream->textTokensMask(); - memcpy(merged_text_mask + token_idx, text_token_mask.data(), text_token_mask.size() * sizeof(int)); - } - } - - if (need_cal_position_id) { - auto context_pos_ids = stream->generateContextPositionIds(); - int reuse_offset = stream->reuseLength() * position_id_len_factor_; - memcpy(combo_position_ids + token_idx * position_id_len_factor_, - context_pos_ids.data_ptr() + reuse_offset, - (context_pos_ids.numel() - reuse_offset) * sizeof(int)); - } - if (max_blocks_num) { - RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_kernel_block_id.dim() == 3, - "hybrid kv_cache_kernel_block_id must be 3-D"); - RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_block_id.dim() == 3, - "hybrid kv_cache_block_id must be 3-D"); - const size_t batch = model_input.kv_cache_kernel_block_id.size(1); - int32_t* kernel_dst_base = model_input.kv_cache_kernel_block_id.data_ptr(); - int32_t* store_dst_base = model_input.kv_cache_block_id.data_ptr(); - for (int gid = 0; gid < kv_cache.groupNums(); ++gid) { - auto& kernel_blocks = kv_cache.kernelBlocks(i, gid); - int32_t* kernel_dst = kernel_dst_base - + (static_cast(gid) * batch + static_cast(batch_idx)) - * max_blocks_num * kernel_blocks_per_kv_block_; - std::memcpy(kernel_dst, kernel_blocks.data(), kernel_blocks.size() * sizeof(int32_t)); - - auto& physical_blocks = kv_cache.blocks(i, gid); - int32_t* store_dst = - store_dst_base - + (static_cast(gid) * batch + static_cast(batch_idx)) * max_blocks_num; - std::memcpy(store_dst, physical_blocks.data(), physical_blocks.size() * sizeof(int32_t)); - } - if (role_type_ == RoleType::PREFILL && stream->hasCacheKeys()) { - std::memcpy(model_input.cache_keys.data_ptr() - + (batch_idx - total_decode_batch_size) * model_input.cache_keys.size(1), - stream->cacheKeys(i).data(), - stream->cacheKeys(i).size() * sizeof(int64_t)); - } - } - *(model_input.request_id.data_ptr() + (batch_idx - total_decode_batch_size)) = stream->streamId(); - *(reinterpret_cast(model_input.request_pd_separation.data_ptr()) - + (batch_idx - total_decode_batch_size)) = stream->queryPdSep(); - batch_idx += 1; - token_idx += input_tokens.size(); - } - - if (max_blocks_num) { - add_cache_update_copy(stream->streamCacheResource().getKVBlockUpdateMapping()); - } +NormalBatchStreamProcessor::NormalBatchStreamProcessor( + const ModelConfig& model_config, + const PDSepConfig& pd_sep_config, + const ProfilingDebugLoggingConfig& profiling_debug_logging_config, + const CacheConfig& cache_config, + bool warm_up) { + model_input_gatherer_config_.num_layers = model_config.num_layers; + model_input_gatherer_config_.vocab_size = model_config.vocab_size; + model_input_gatherer_config_.input_vocab_size = model_config.input_vocab_size; + model_input_gatherer_config_.has_positional_encoding = model_config.has_positional_encoding; + model_input_gatherer_config_.is_multimodal = model_config.mm_model_config.is_multimodal; + model_input_gatherer_config_.mm_position_ids_style = + static_cast(model_config.mm_model_config.mm_position_ids_style); + model_input_gatherer_config_.position_id_len_factor = model_config.attn_config.rope_config.index_factor; + model_input_gatherer_config_.role_type = pd_sep_config.role_type; + model_input_gatherer_config_.decode_entrance = pd_sep_config.decode_entrance; + model_input_gatherer_config_.block_stride_bytes = cache_config.kv_block_stride_bytes; + model_input_gatherer_config_.scale_stride_bytes = cache_config.kv_scale_stride_bytes; + model_input_gatherer_config_.seq_size_per_block = cache_config.seq_size_per_block; + model_input_gatherer_config_.kernel_seq_size_per_block = cache_config.kernel_seq_size_per_block; + model_input_gatherer_config_.kernel_blocks_per_kv_block = cache_config.kernelBlocksPerKvBlock(); + model_input_gatherer_config_.kv_cache_group_nums = cache_config.groupNums(); + model_input_gatherer_config_.layer_to_kv_cache_group_id = cache_config.layer_to_group_id; + model_input_gatherer_config_.kv_cache_group_types = cache_config.group_types; + model_input_gatherer_config_.warm_up = warm_up; + model_input_gatherer_config_.enable_detail_log = profiling_debug_logging_config.enable_detail_log; + + model_input_gatherer_ = std::make_unique(model_input_gatherer_config_); + sampler_input_gatherer_ = std::make_unique(); + output_dispatcher_ = std::make_unique(); +} - stream->step(); - } +absl::Status NormalBatchStreamProcessor::dispatch(const StreamGroups& stream_groups, + const MergedOutput& merge_outputs) const { + return output_dispatcher_->dispatch(stream_groups, merge_outputs); +} - if (is_multimodal_ && gathered_mm_features.size() > 0) { - model_input.multimodal_features = std::move(gathered_mm_features); - } - return model_input; +absl::StatusOr NormalBatchStreamProcessor::gatherModelInput(const StreamGroups& stream_groups) const { + return model_input_gatherer_->gather(stream_groups); } absl::StatusOr NormalBatchStreamProcessor::gatherSamplerInput( const StreamGroups& stream_groups, const GptModelInputs& model_inputs, const GptModelOutputs& model_output) const { - RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); - RTP_LLM_CHECK(!stream_groups.empty()); - auto all_streams = stream_groups.allStreams(); - auto total_batch_size_in = stream_groups.totalSamplerBatchSizeIn(); - auto total_batch_size_out = stream_groups.totalSamplerBatchSizeOut(); - bool return_all_probs = stream_groups.needReturnAllProbs(); - - SamplerInputs sampler_inputs = - allocateSamplerInputs(stream_groups, total_batch_size_in, total_batch_size_out, model_inputs.sequence_lengths); - setCommonSamplerInputs(sampler_inputs, all_streams); - - setLogitsProcessorInputs(sampler_inputs, all_streams); - - size_t total_decode_batch_size_in = 0; - int batch_idx = 0; - bool return_logits = false; - bool calculate_softmax_probs = false; - bool need_tiling = false; - for (auto& stream : all_streams) { - auto complete_token_ids = stream->completeTokenIds(); - auto complete_seq_len = complete_token_ids.size(1); - auto seq_len = stream->seqLength(); - auto current_batch_size = stream->currentBatchSize(); - auto sampler_batch_size = - stream->needTilingForSampling() ? stream->nextBatchSize() : stream->currentBatchSize(); - - for (int i = 0; i < sampler_batch_size; ++i) { - int cur_batch = std::min(i, current_batch_size - 1); - memcpy(sampler_inputs.token_ids.data_ptr() + ((batch_idx) * (sampler_inputs.step + 1)), - complete_token_ids.data_ptr() + cur_batch * complete_seq_len, - seq_len * sizeof(int)); - reinterpret_cast(sampler_inputs.finished_mask.data_ptr())[batch_idx] = stream->isDoneWithoutLock(i); - batch_idx += 1; - } - need_tiling |= stream->needTilingForSampling(); - if (!stream->isContextStream()) { - total_decode_batch_size_in += sampler_batch_size; - } - return_logits |= stream->returnLogits(); - calculate_softmax_probs |= stream->calculateSoftmaxProbs(); - RTP_LLM_LOG_DEBUG("stream [%ld], sampler inputs token ids = [%s]", - stream->streamId(), - tensorDebugStringWithData(sampler_inputs.token_ids).c_str()); - } - - auto vocab_size = (size_t)model_output.logits.size(1); - sampler_inputs.vocab_size = vocab_size; - if (return_all_probs) { - sampler_inputs.all_probs = torch::zeros({(int64_t)total_batch_size_in, (int64_t)vocab_size}, - torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); - } - - // copy logits when needs tiling or returning logits - torch::Tensor logits_tensor; - if (need_tiling) { - logits_tensor = - torch::empty({(int64_t)total_batch_size_in, (int64_t)vocab_size}, model_output.logits.options()); - // copy decode batch logits - if (total_decode_batch_size_in > 0) { - logits_tensor.narrow(0, 0, total_decode_batch_size_in) - .copy_(model_output.logits.narrow(0, 0, total_decode_batch_size_in)); - } - // tile context batch logits - size_t input_offset = total_decode_batch_size_in, logits_offset = total_decode_batch_size_in; - for (auto& stream : stream_groups.contextStreams()) { - auto sampler_batch_size = - stream->needTilingForSampling() ? stream->nextBatchSize() : stream->currentBatchSize(); - for (int i = 0; i < sampler_batch_size; ++i) { - logits_tensor[input_offset].copy_(model_output.logits[logits_offset]); - input_offset += 1; - } - logits_offset += 1; - } - } else if (return_logits || calculate_softmax_probs) { - logits_tensor = model_output.logits.clone(); - } else { - logits_tensor = model_output.logits; - } - sampler_inputs.logits = logits_tensor; - - RTP_LLM_LOG_DEBUG("sampler inputs logits [%s]", - tensorDebugStringWithData(sampler_inputs.logits.cpu(), 10).c_str()); - - RTP_LLM_LOG_DEBUG("gatherSamplerInput done"); - return std::move(sampler_inputs); + return sampler_input_gatherer_->gather(stream_groups, model_inputs, model_output); } -SamplerInputs NormalBatchStreamProcessor::allocateSamplerInputs(const StreamGroups& stream_groups, - size_t total_batch_size_in, - size_t total_batch_size_out, - const torch::Tensor& sequence_lengths, - size_t propose_step) const { - // TODO(xinfei.sxf) don't sample for chunk stream - SamplerInputs sampler_inputs; - sampler_inputs.step = stream_groups.maxSeqLen() + propose_step; - sampler_inputs.batch_size = total_batch_size_in; - sampler_inputs.batch_size_out = total_batch_size_out; - auto bs = (int64_t)total_batch_size_in; - sampler_inputs.sequence_lengths = torch::empty({bs}, torch::kInt32); - sampler_inputs.logits_processor_states_ptr.reset(); - sampler_inputs.input_lengths = torch::empty({bs}, torch::kInt32); - sampler_inputs.num_beams_in = torch::empty({bs}, torch::kLong); - sampler_inputs.num_beams_out = torch::empty({bs}, torch::kLong); - static const auto pinned_int = torch::TensorOptions(torch::kInt).pinned_memory(true); - static const auto pinned_i32 = torch::TensorOptions(torch::kInt32).pinned_memory(true); - static const auto pinned_f32 = torch::TensorOptions(torch::kFloat32).pinned_memory(true); - static const auto pinned_bool = torch::TensorOptions(torch::kBool).pinned_memory(true); - - sampler_inputs.top_k = torch::empty({bs}, pinned_int); - sampler_inputs.top_p = torch::empty({bs}, pinned_f32); - sampler_inputs.temperature = torch::empty({bs}, pinned_f32); - sampler_inputs.repetition_penalty = torch::empty({bs}, pinned_f32); - sampler_inputs.presence_penalty = torch::empty({bs}, pinned_f32); - sampler_inputs.frequency_penalty = torch::empty({bs}, pinned_f32); - sampler_inputs.no_repeat_ngram_size = torch::empty({bs}, pinned_i32); - sampler_inputs.do_sample = torch::empty({bs}, pinned_bool); - sampler_inputs.finished_mask = torch::empty({bs}, torch::kBool); - if (stream_groups.needReturnCumLogProbs()) { - sampler_inputs.cum_log_probs = torch::empty({(int64_t)total_batch_size_in}, torch::kFloat32); - } - sampler_inputs.token_ids = - torch::empty({(int64_t)total_batch_size_in, (int64_t)(sampler_inputs.step + 1)}, torch::kInt32); - sampler_inputs.generator.resize(total_batch_size_in); - return sampler_inputs; +SamplerInputs NormalBatchStreamProcessor::allocateSamplerInputs(const StreamGroups& stream_groups, + size_t total_batch_size_in, + size_t total_batch_size_out, + size_t propose_step) const { + return sampler_input_gatherer_->allocateSamplerInputs( + stream_groups, total_batch_size_in, total_batch_size_out, propose_step); } -void NormalBatchStreamProcessor::setCommonSamplerInputs(SamplerInputs& sampler_inputs, - std::list& all_streams, - bool score_batch, - size_t propose_step) const { - int* input_lengths = sampler_inputs.input_lengths.data_ptr(); - int* sequence_lengths = sampler_inputs.sequence_lengths.data_ptr(); - uint64_t* num_beams_in = reinterpret_cast(sampler_inputs.num_beams_in.data_ptr()); - uint64_t* num_beams_out = reinterpret_cast(sampler_inputs.num_beams_out.data_ptr()); - uint32_t* top_k = reinterpret_cast(sampler_inputs.top_k.data_ptr()); - float* top_p = sampler_inputs.top_p.data_ptr(); - float* temperature = sampler_inputs.temperature.data_ptr(); - float* repetition_penalty = sampler_inputs.repetition_penalty.data_ptr(); - float* presence_penalty = sampler_inputs.presence_penalty.data_ptr(); - float* frequency_penalty = sampler_inputs.frequency_penalty.data_ptr(); - int32_t* no_repeat_ngram_size = sampler_inputs.no_repeat_ngram_size.data_ptr(); - bool* do_sample = reinterpret_cast(sampler_inputs.do_sample.data_ptr()); - - int batch_idx = 0; - for (auto& stream : all_streams) { - int sampler_batch_size; - if (score_batch) { - sampler_batch_size = stream->scoreLen(); - } else if (stream->needTilingForSampling()) { - sampler_batch_size = stream->nextBatchSize(); - } else { - sampler_batch_size = stream->currentBatchSize(); - } - if (sampler_inputs.cum_log_probs.defined()) { - const auto& cum_log_probs = stream->cumLogProbs(); - memcpy(sampler_inputs.cum_log_probs.data_ptr() + batch_idx, - cum_log_probs.data_ptr(), - cum_log_probs.numel() * sizeof(float)); - } - for (int i = 0; i < sampler_batch_size; ++i) { - input_lengths[batch_idx] = stream->inputLength(); - sequence_lengths[batch_idx] = stream->seqLength() + propose_step; - num_beams_in[batch_idx] = stream->currentNumBeams(); - num_beams_out[batch_idx] = stream->nextNumBeams(); - top_k[batch_idx] = stream->generateConfig()->top_k; - top_p[batch_idx] = stream->generateConfig()->top_p; - temperature[batch_idx] = stream->generateConfig()->temperature; - repetition_penalty[batch_idx] = stream->generateConfig()->repetition_penalty; - presence_penalty[batch_idx] = stream->generateConfig()->presence_penalty; - frequency_penalty[batch_idx] = stream->generateConfig()->frequency_penalty; - do_sample[batch_idx] = stream->generateConfig()->do_sample; - if (!do_sample[batch_idx]) { - top_k[batch_idx] = 1; - top_p[batch_idx] = 1; - temperature[batch_idx] = 1; - } - no_repeat_ngram_size[batch_idx] = stream->generateConfig()->no_repeat_ngram_size.value_or(0); - sampler_inputs.generator[batch_idx] = stream->getGenerator(); - batch_idx += 1; - } - } +void NormalBatchStreamProcessor::fillSamplerCommonInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch, + size_t propose_step) const { + sampler_input_gatherer_->fillSamplerCommonInputs(sampler_inputs, all_streams, score_batch, propose_step); } void NormalBatchStreamProcessor::setLogitsProcessorInputs(SamplerInputs& sampler_inputs, std::list& all_streams, bool score_batch) const { - LogitsProcessorStatesPtr state_ptr = std::make_shared(); - std::for_each(all_streams.begin(), all_streams.end(), [&state_ptr, idx = 0](auto& stream) mutable { - for (const auto& processor : stream->getAllLogitsProcessorPtr()) { - state_ptr->insert(processor, idx, idx + stream->currentBatchSize()); - } - idx += stream->currentBatchSize(); - }); - sampler_inputs.logits_processor_states_ptr = state_ptr; -} - -absl::Status NormalBatchStreamProcessor::dispatch(const StreamGroups& stream_groups, - const MergedOutput& merge_outputs) const { - RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); - const auto& sampler_output = merge_outputs.sampler_output; - const auto& new_all_token_ids = sampler_output.token_ids; - RTP_LLM_LOG_DEBUG("new_all_token_ids = [%s]", tensorDebugStringWithData(new_all_token_ids).c_str()); - const size_t total_batch_size_out = stream_groups.totalSamplerBatchSizeOut(); - RTP_LLM_CHECK(total_batch_size_out == (size_t)new_all_token_ids.size(0)); - int batch_idx_in = 0; - int batch_idx_out = 0; - int token_offset = 0; - bool return_all_probs = stream_groups.needReturnAllProbs(); - auto new_tokens_all = torch::empty({(int64_t)total_batch_size_out, 1}, torch::kInt32); - - for (auto& stream : stream_groups.allStreams()) { - auto cur_batch_size = stream->currentBatchSize(); - auto next_batch_size = stream->nextBatchSize(); - auto token_size = stream->currentExecuteTokenSize(); - - dispatchSingleStream( - stream, merge_outputs, batch_idx_in, batch_idx_out, token_offset, return_all_probs, new_tokens_all); - - batch_idx_in += cur_batch_size; - batch_idx_out += next_batch_size; - token_offset += token_size; - } - - RTP_LLM_LOG_DEBUG("dispatch done"); - return absl::OkStatus(); -} - -void NormalBatchStreamProcessor::dispatchSingleStream(GenerateStreamPtr stream, - const MergedOutput& merge_outputs, - int batch_idx_in, - int batch_idx_out, - int token_offset, - bool return_all_probs, - const torch::Tensor& new_tokens_all) const { - - const auto& model_output = merge_outputs.model_output; - const auto& sampler_output = merge_outputs.sampler_output; - const auto& new_all_token_ids = sampler_output.token_ids; - const size_t token_stride = new_all_token_ids.size(1); - - auto cur_batch_size = stream->currentBatchSize(); - auto next_batch_size = stream->nextBatchSize(); - auto token_size = stream->currentExecuteTokenSize(); - - auto batch_new_all_token_ids = new_all_token_ids.narrow(0, batch_idx_out, next_batch_size); - - bool has_beam_search = stream->currentNumBeams() > 1 || stream->nextNumBeams() > 1; - bool has_var_batch = stream->currentBatchSize() != stream->nextBatchSize(); - - // construct mapping from output batches to input batches - torch::Tensor src_batch_indices; - if (has_beam_search) { - // beam search - src_batch_indices = sampler_output.beam_index.narrow(0, batch_idx_out, next_batch_size); - } else if (has_var_batch) { - // from context stream to decode straem, there might be other cases in future - src_batch_indices = torch::zeros({(int64_t)next_batch_size}, torch::kInt32); - } - const auto get_src_idx = [&](int32_t dst_idx) { - return src_batch_indices.defined() ? src_batch_indices.data_ptr()[dst_idx] : dst_idx; - }; - - // construct update info - torch::Tensor batch_hidden_states; - if (stream->generateConfig()->return_hidden_states) { - batch_hidden_states = model_output.hidden_states.narrow(0, batch_idx_in, cur_batch_size); - } - - torch::Tensor batch_logits; - if (stream->returnLogits() || stream->calculateSoftmaxProbs() || has_beam_search) { - batch_logits = model_output.logits.narrow(0, batch_idx_in, cur_batch_size); - } - - torch::Tensor all_probs; - if (return_all_probs) { - all_probs = sampler_output.all_probs.narrow(0, batch_idx_out, next_batch_size); - }; - - torch::Tensor batch_cum_log_probs; - if (sampler_output.cum_log_probs.defined()) { - batch_cum_log_probs = sampler_output.cum_log_probs.narrow(0, batch_idx_out, next_batch_size); - } - - torch::Tensor loss; - if (stream->calculateLoss()) { - auto all_logits_tensor = model_output.all_logits.narrow(0, token_offset, token_size - 1); - auto tokens = stream->currentExecuteTokens(0); - auto label_tensor = - torch::from_blob(const_cast(tokens.data() + 1), {(int64_t)(tokens.size() - 1)}, torch::kInt32) - .to(torch::kCUDA); - auto labels_int64 = label_tensor.toType(torch::kInt64); - loss = torch::cross_entropy_loss(all_logits_tensor, labels_int64, torch::nullopt, at::Reduction::None) - .to(torch::kFloat32); - } - - torch::Tensor all_hidden_states; - if (stream->needReturnHiddenStates()) { - all_hidden_states = model_output.all_hidden_states.narrow(0, token_offset, token_size); - } - - auto new_tokens = new_tokens_all.narrow(0, batch_idx_out, next_batch_size); - for (size_t i = 0; i < next_batch_size; ++i) { - new_tokens.data_ptr()[i] = - new_all_token_ids.data_ptr()[(batch_idx_out + i) * token_stride + token_stride - 1]; - } - - torch::Tensor current_softmax_result; - if (stream->calculateSoftmaxProbs()) { - auto batch_softmax_input = batch_logits.to(torch::kFloat32).contiguous(); -#if USING_CUDA - cudaSoftmaxInplace(batch_softmax_input, at::cuda::getCurrentCUDAStream().stream()); -#else - batch_softmax_input = torch::softmax(batch_softmax_input, -1); -#endif - auto batch_softmax_tensor = batch_softmax_input.cpu(); - current_softmax_result = torch::empty({(int64_t)next_batch_size, 1}, torch::kFloat32); - for (int i = 0; i < next_batch_size; ++i) { - current_softmax_result[i][0] = batch_softmax_tensor[get_src_idx(i)][new_tokens.data_ptr()[i]]; - } - } - - for (int i = 0; i < cur_batch_size; ++i) { - if (sampler_output.success.defined() && !(sampler_output.success.data_ptr()[batch_idx_in + i])) { - stream->setStop(ErrorCode::UNKNOWN_ERROR, "sampler generate token id failed"); - } - } - - RTP_LLM_LOG_DEBUG("stream [%ld], new_tokens size = [%ld]", stream->streamId(), new_tokens.numel()); - - stream->update({has_beam_search ? batch_new_all_token_ids : new_tokens, - 1, - batch_hidden_states, - batch_logits, - current_softmax_result, - batch_cum_log_probs, - all_probs, - loss, - src_batch_indices, - all_hidden_states}); + sampler_input_gatherer_->setLogitsProcessorInputs(sampler_inputs, all_streams, score_batch); } } // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.h b/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.h index ce591fa8e5..d7221807a0 100644 --- a/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.h +++ b/rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.h @@ -1,12 +1,18 @@ #pragma once +#include #include -#include "rtp_llm/cpp/cache/CacheGroupType.h" + +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "rtp_llm/cpp/cache/CacheConfig.h" #include "rtp_llm/cpp/config/ConfigModules.h" -#include "rtp_llm/cpp/models/SampleInfos.h" #include "rtp_llm/cpp/engine_base/stream/StreamGroups.h" -#include "absl/status/statusor.h" -#include "absl/status/status.h" +#include "rtp_llm/cpp/models/SampleInfos.h" +#include "rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h" +#include "rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h" +#include "rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h" namespace rtp_llm { @@ -16,27 +22,7 @@ class NormalBatchStreamProcessor { const PDSepConfig& pd_sep_config, const ProfilingDebugLoggingConfig& profiling_debug_logging_config, const CacheConfig& cache_config, - bool warm_up): - num_layers_(model_config.num_layers), - vocab_size_(model_config.vocab_size), - input_vocab_size_(model_config.input_vocab_size), - use_int8_kv_cache_(model_config.attn_config.kv_cache_dtype == rtp_llm::KvCacheDataType::INT8), - has_positional_encoding_(model_config.has_positional_encoding), - is_multimodal_(model_config.mm_model_config.is_multimodal), - mm_position_ids_style_((PositionIdsStyle)model_config.mm_model_config.mm_position_ids_style), - position_id_len_factor_(model_config.attn_config.rope_config.index_factor), - role_type_(pd_sep_config.role_type), - decode_entrance_(pd_sep_config.decode_entrance), - block_stride_bytes_(cache_config.kv_block_stride_bytes), - scale_stride_bytes_(cache_config.kv_scale_stride_bytes), - seq_size_per_block_(cache_config.seq_size_per_block), - kernel_seq_size_per_block_(cache_config.kernel_seq_size_per_block), - kernel_blocks_per_kv_block_(cache_config.kernelBlocksPerKvBlock()), - kv_cache_group_nums_(cache_config.groupNums()), - layer_to_kv_cache_group_id_(cache_config.layer_to_group_id), - kv_cache_group_types_(cache_config.group_types), - warm_up_(warm_up), - enable_detail_log_(profiling_debug_logging_config.enable_detail_log) {} + bool warm_up); virtual absl::Status dispatch(const StreamGroups& stream_groups, const MergedOutput& merge_outputs) const; virtual absl::StatusOr gatherModelInput(const StreamGroups& stream_groups) const; @@ -45,52 +31,32 @@ class NormalBatchStreamProcessor { const GptModelOutputs& model_output) const; protected: - SamplerInputs allocateSamplerInputs(const StreamGroups& stream_groups, - size_t total_batch_size_in, - size_t total_batch_size_out, - const torch::Tensor& sequence_length, - size_t propose_step = 0) const; - void setCommonSamplerInputs(SamplerInputs& sampler_inputs, - std::list& all_streams, - bool score_batch = false, - size_t propose_step = 0) const; - void setLogitsProcessorInputs(SamplerInputs& sampler_inputs, - std::list& all_streams, - bool score_batch = false) const; + SamplerInputs allocateSamplerInputs(const StreamGroups& stream_groups, + size_t total_batch_size_in, + size_t total_batch_size_out, + size_t propose_step = 0) const; + + void setCommonSamplerInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch = false, + size_t propose_step = 0) const { + fillSamplerCommonInputs(sampler_inputs, all_streams, score_batch, propose_step); + } - void dispatchSingleStream(GenerateStreamPtr stream, - const MergedOutput& merge_outputs, - int batch_idx_in, - int batch_idx_out, - int token_offset, - bool return_all_probs, - const torch::Tensor& new_tokens_all) const; + void fillSamplerCommonInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch = false, + size_t propose_step = 0) const; - void setKVCacheGroupTypes(std::vector kv_cache_group_types) { - kv_cache_group_types_ = kv_cache_group_types; - } + void setLogitsProcessorInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch = false) const; protected: - size_t num_layers_; - size_t vocab_size_; - size_t input_vocab_size_; - bool use_int8_kv_cache_; - bool has_positional_encoding_; - bool is_multimodal_; - PositionIdsStyle mm_position_ids_style_; - size_t position_id_len_factor_; - RoleType role_type_; - bool decode_entrance_; - size_t block_stride_bytes_; - size_t scale_stride_bytes_; - size_t seq_size_per_block_; - size_t kernel_seq_size_per_block_; - size_t kernel_blocks_per_kv_block_ = 1; - size_t kv_cache_group_nums_ = 1; - mutable std::vector layer_to_kv_cache_group_id_; - std::vector kv_cache_group_types_; - bool warm_up_; - bool enable_detail_log_; + NormalModelInputGathererConfig model_input_gatherer_config_; + std::unique_ptr model_input_gatherer_; + std::unique_ptr sampler_input_gatherer_; + std::unique_ptr output_dispatcher_; }; } // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.cc b/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.cc new file mode 100644 index 0000000000..914c74d249 --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.cc @@ -0,0 +1,361 @@ +#include +#include +#include +#include "torch/all.h" +#include "rtp_llm/cpp/cache/Types.h" +#include "rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h" +#include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/utils/StatusUtil.h" + +namespace rtp_llm { + +namespace { + +struct GatherModelInputContext { + int input_vocab_size; + bool need_cal_position_id; + size_t max_blocks_num; + int* merged_tokens; + int* input_lengths; + int* lm_output_indexes; + int* lm_output_lengths; + int* combo_position_ids; + BlockIdPair* kv_cache_update_mapping; + int batch_idx; + int* sequence_lengths; + bool has_multimodal_input; + size_t total_decode_batch_size; + int* prefix_lengths; + int* merged_text_mask; + int* mm_features_locs; + int token_idx; + int cum_output_seq_len; + int mm_feature_index; +}; + +enum class GatherContextMode { + DECODE, + CONTEXT +}; + +GatherModelInputContext createGatherContext(const NormalModelInputGathererConfig& config, + GptModelInputs& model_input, + const StreamGroups& stream_groups, + GatherContextMode mode) { + GatherModelInputContext ctx{}; + ctx.input_vocab_size = + config.input_vocab_size ? static_cast(config.input_vocab_size) : static_cast(config.vocab_size); + ctx.need_cal_position_id = + (config.mm_position_ids_style != PositionIdsStyle::DEFAULT) || config.has_positional_encoding; + ctx.max_blocks_num = stream_groups.curBlocksNum(); + ctx.merged_tokens = model_input.combo_tokens.data_ptr(); + ctx.input_lengths = model_input.input_lengths.data_ptr(); + ctx.sequence_lengths = model_input.sequence_lengths.data_ptr(); + ctx.lm_output_indexes = model_input.lm_output_indexes.data_ptr(); + ctx.lm_output_lengths = model_input.lm_output_lengths.data_ptr(); + ctx.combo_position_ids = ctx.need_cal_position_id ? model_input.combo_position_ids.data_ptr() : nullptr; + ctx.has_multimodal_input = config.is_multimodal && stream_groups.has_multimodal_input(); + ctx.prefix_lengths = model_input.prefix_lengths.data_ptr(); + ctx.merged_text_mask = ctx.has_multimodal_input ? model_input.text_tokens_mask.data_ptr() : nullptr; + ctx.mm_features_locs = ctx.has_multimodal_input ? model_input.mm_features_locs.data_ptr() : nullptr; + + size_t kv_cache_mapping_offset = 0; + if (mode == GatherContextMode::DECODE) { + ctx.batch_idx = 0; + } else { + ctx.total_decode_batch_size = stream_groups.totalDecodeBatchSize(); + ctx.batch_idx = static_cast(ctx.total_decode_batch_size); + ctx.token_idx = ctx.batch_idx; + ctx.cum_output_seq_len = ctx.batch_idx; + ctx.mm_feature_index = 0; + kv_cache_mapping_offset = stream_groups.decodeBlockUpdateCopyNum(); + } + ctx.kv_cache_update_mapping = + model_input.kv_cache_update_mapping.defined() ? + reinterpret_cast(model_input.kv_cache_update_mapping.data_ptr()) + kv_cache_mapping_offset : + nullptr; + + if (ctx.merged_text_mask) { + size_t current_tokens_size = stream_groups.modelExecuteTokenSize(); + std::fill(ctx.merged_text_mask, ctx.merged_text_mask + current_tokens_size, 1); + } + + return ctx; +} + +void copyKvCacheBlocksToModelInput(GptModelInputs& model_input, + const BatchKVCacheResource& kv_cache, + int stream_batch_idx, + int model_batch_idx, + size_t max_blocks_num, + size_t kernel_blocks_per_kv_block) { + if (!model_input.kv_cache_kernel_block_id.defined() || max_blocks_num == 0) { + return; + } + RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_kernel_block_id.dim() == 3, + "hybrid kv_cache_kernel_block_id must be 3-D"); + RTP_LLM_CHECK_WITH_INFO(model_input.kv_cache_block_id.dim() == 3, "hybrid kv_cache_block_id must be 3-D"); + + const size_t batch = model_input.kv_cache_kernel_block_id.size(1); + int32_t* kernel_dst_base = model_input.kv_cache_kernel_block_id.data_ptr(); + int32_t* store_dst_base = model_input.kv_cache_block_id.data_ptr(); + + for (int gid = 0; gid < kv_cache.groupNums(); ++gid) { + auto& kernel_blocks = kv_cache.kernelBlocks(stream_batch_idx, gid); + int32_t* kernel_dst = kernel_dst_base + + (static_cast(gid) * batch + static_cast(model_batch_idx)) + * max_blocks_num * kernel_blocks_per_kv_block; + std::memcpy(kernel_dst, kernel_blocks.data(), kernel_blocks.size() * sizeof(int32_t)); + + auto& physical_blocks = kv_cache.blocks(stream_batch_idx, gid); + int32_t* store_dst = + store_dst_base + (static_cast(gid) * batch + static_cast(model_batch_idx)) * max_blocks_num; + std::memcpy(store_dst, physical_blocks.data(), physical_blocks.size() * sizeof(int32_t)); + } +} + +void gatherMultimodalFeaturesForContextBatch(const GenerateStreamPtr& stream, + GatherModelInputContext& ctx, + std::vector& gathered_mm_features) { + if (!ctx.has_multimodal_input) { + return; + } + std::vector mm_features = stream->multimodalFeatures(); + torch::Tensor mm_locs = stream->multimodalLocations(); + if (!mm_locs.defined()) { + return; + } + auto* mm_locs_data = mm_locs.data_ptr(); + for (int i = 0; i < mm_locs.numel(); ++i) { + ctx.mm_features_locs[ctx.mm_feature_index] = mm_locs_data[i] + ctx.token_idx - stream->reuseLength(); + ctx.mm_feature_index++; + } + for (auto& mm_feature : mm_features) { + if (!mm_feature.is_cuda()) { + gathered_mm_features.emplace_back(mm_feature.to(torch::kCUDA)); + } else { + gathered_mm_features.emplace_back(mm_feature); + } + } + auto text_token_mask = stream->textTokensMask(); + memcpy(ctx.merged_text_mask + ctx.token_idx, text_token_mask.data(), text_token_mask.size() * sizeof(int)); +} + +void addCacheUpdateCopy(GatherModelInputContext& ctx, const std::vector& update_mapping) { + if (!ctx.kv_cache_update_mapping) { + return; + } + size_t update_copy_num = update_mapping.size(); + std::memcpy(ctx.kv_cache_update_mapping, update_mapping.data(), update_copy_num * sizeof(BlockIdPair)); + ctx.kv_cache_update_mapping += update_copy_num; +} + +} // anonymous namespace + +NormalModelInputGatherer::NormalModelInputGatherer(const NormalModelInputGathererConfig& config): config_(config) {} + +GptModelInputs NormalModelInputGatherer::allocateModelInputBuffers(const StreamGroups& stream_groups) const { + const size_t current_tokens_size = stream_groups.modelExecuteTokenSize(); + const size_t total_batch_size = stream_groups.totalModelBatchSize(); + const size_t total_decode_batch_size = stream_groups.totalDecodeBatchSize(); + const size_t total_context_batch_size = stream_groups.totalContextBatchSize(); + const size_t total_block_copy_num = stream_groups.totalBlockUpdateCopyNum(); + const size_t max_blocks_num = stream_groups.curBlocksNum(); + const size_t multimodal_features_len = stream_groups.mmFeaturesLen(); + const bool has_multimodal_input = config_.is_multimodal && stream_groups.has_multimodal_input(); + const bool need_cal_position_id = + (config_.mm_position_ids_style != PositionIdsStyle::DEFAULT) || config_.has_positional_encoding; + + static const auto pinned_i32 = torch::TensorOptions(torch::kInt32).pinned_memory(true); + static const auto pinned_i64 = torch::TensorOptions(torch::kInt64).pinned_memory(true); + static const auto pinned_bool = torch::TensorOptions(torch::kBool).pinned_memory(true); + + GptModelInputs model_input; + model_input.combo_tokens = torch::empty({(int64_t)current_tokens_size}, pinned_i32); + model_input.input_lengths = torch::empty({(int64_t)total_batch_size}, pinned_i32); + model_input.sequence_lengths = torch::empty({(int64_t)total_decode_batch_size}, pinned_i32); + model_input.lm_output_indexes = torch::empty({(int64_t)total_batch_size}, pinned_i32); + model_input.lm_output_lengths = torch::empty({(int64_t)total_batch_size}, pinned_i32); + model_input.prefix_lengths = torch::empty({(int64_t)total_context_batch_size}, pinned_i32); + model_input.request_id = torch::empty({(int64_t)total_context_batch_size}, pinned_i64); + model_input.request_pd_separation = torch::empty({(int64_t)total_context_batch_size}, pinned_bool); + + if (max_blocks_num) { + model_input.kv_cache_kernel_block_id = + torch::zeros({(int64_t)config_.kv_cache_group_nums, + (int64_t)total_batch_size, + (int64_t)(max_blocks_num * config_.kernel_blocks_per_kv_block)}, + pinned_i32); + model_input.kv_cache_block_id = torch::zeros( + {(int64_t)config_.kv_cache_group_nums, (int64_t)total_batch_size, (int64_t)max_blocks_num}, pinned_i32); + model_input.kv_cache_layer_to_group = torch::empty({(int64_t)config_.num_layers}, pinned_i32); + model_input.kv_cache_group_types = torch::empty({(int64_t)config_.kv_cache_group_nums}, pinned_i32); + model_input.kv_cache_update_mapping = torch::empty({(int64_t)total_block_copy_num, 2}, pinned_i32); + model_input.cache_keys = torch::empty({(int64_t)total_context_batch_size, (int64_t)max_blocks_num}, pinned_i64); + } + + if (need_cal_position_id) { + model_input.combo_position_ids = + torch::empty({(int64_t)(current_tokens_size * config_.position_id_len_factor)}, pinned_i32); + } + if (has_multimodal_input) { + model_input.text_tokens_mask = torch::empty({(int64_t)current_tokens_size}, pinned_i32); + model_input.mm_features_locs = torch::empty({(int64_t)multimodal_features_len}, pinned_i32); + } + + model_input.kv_block_stride_bytes = config_.block_stride_bytes; + model_input.kv_scale_stride_bytes = config_.scale_stride_bytes; + model_input.seq_size_per_block = config_.seq_size_per_block; + model_input.kernel_seq_size_per_block = config_.kernel_seq_size_per_block; + model_input.pd_separation = config_.role_type == RoleType::PREFILL; + model_input.warmup = config_.warm_up; + model_input.decode_entrance = config_.decode_entrance; + model_input.is_fake_stream = stream_groups.isFakeStream(); + + return model_input; +} + +void NormalModelInputGatherer::initializeKvCacheMetadata(GptModelInputs& model_input) const { + if (model_input.kv_cache_layer_to_group.defined()) { + size_t num_layers = config_.layer_to_kv_cache_group_id.size(); + std::memcpy(model_input.kv_cache_layer_to_group.data_ptr(), + config_.layer_to_kv_cache_group_id.data(), + num_layers * sizeof(int32_t)); + } + if (model_input.kv_cache_group_types.defined()) { + auto* dst = model_input.kv_cache_group_types.data_ptr(); + for (size_t g = 0; g < config_.kv_cache_group_nums; ++g) { + dst[g] = static_cast(config_.kv_cache_group_types[g]); + } + } +} + +absl::Status NormalModelInputGatherer::processDecodeStreams(GptModelInputs& model_input, + const StreamGroups& stream_groups) const { + auto ctx = createGatherContext(config_, model_input, stream_groups, GatherContextMode::DECODE); + + for (const auto& stream : stream_groups.decodeStreams()) { + model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss(); + auto current_batch_size = stream->currentBatchSize(); + auto& kv_cache = *stream->kvCachePtr(); + RTP_LLM_LOG_DEBUG("decode kv_cache: %s", kv_cache.debugString().c_str()); + RTP_LLM_LOG_DEBUG("decode stream: %s", stream->debugString().c_str()); + + for (auto i = 0; i < current_batch_size; ++i) { + model_input.trace_ids.push_back(stream->traceId()); + auto currentTokens = stream->currentExecuteTokens(i); + if (currentTokens[0] >= ctx.input_vocab_size) { + std::ostringstream error_msg; + error_msg << "stream [" << stream->streamId() << "] token_id " << currentTokens[0] + << " exceed vocab_size " << ctx.input_vocab_size; + return absl::InvalidArgumentError(error_msg.str()); + } + ctx.merged_tokens[ctx.batch_idx] = currentTokens[0]; + ctx.input_lengths[ctx.batch_idx] = stream->inputLength(); + ctx.sequence_lengths[ctx.batch_idx] = stream->seqLength() - 1; + if (ctx.need_cal_position_id) { + stream->generateNextPositionId(ctx.combo_position_ids + ctx.batch_idx * config_.position_id_len_factor); + } + ctx.lm_output_indexes[ctx.batch_idx] = ctx.batch_idx; + ctx.lm_output_lengths[ctx.batch_idx] = 1; + copyKvCacheBlocksToModelInput( + model_input, kv_cache, i, ctx.batch_idx, ctx.max_blocks_num, config_.kernel_blocks_per_kv_block); + ctx.batch_idx += 1; + } + addCacheUpdateCopy(ctx, stream->streamCacheResource().getKVBlockUpdateMapping()); + stream->step(); + } + return absl::OkStatus(); +} + +absl::Status NormalModelInputGatherer::processContextStreams(GptModelInputs& model_input, + const StreamGroups& stream_groups) const { + std::vector gathered_mm_features; + auto ctx = createGatherContext(config_, model_input, stream_groups, GatherContextMode::CONTEXT); + + for (const auto& stream : stream_groups.contextStreams()) { + model_input.need_all_logits = model_input.need_all_logits || stream->calculateLoss(); + auto current_batch_size = stream->currentBatchSize(); + auto& kv_cache = *stream->kvCachePtr(); + if (config_.enable_detail_log) { + RTP_LLM_LOG_DEBUG("context kv_cache: %s", kv_cache.debugString().c_str()); + RTP_LLM_LOG_DEBUG("context stream: %s", stream->debugString().c_str()); + } else { + RTP_LLM_LOG_TRACE("context kv_cache: %s", kv_cache.debugString().c_str()); + RTP_LLM_LOG_TRACE("context stream: %s", stream->debugString().c_str()); + } + + for (auto i = 0; i < current_batch_size; ++i) { + const auto prefill_batch_idx = ctx.batch_idx - ctx.total_decode_batch_size; + model_input.trace_ids.push_back(stream->traceId()); + auto input_tokens = stream->currentExecuteTokens(i); + auto input_masks = stream->textTokensMask(); + memcpy(ctx.merged_tokens + ctx.token_idx, input_tokens.data(), input_tokens.size() * sizeof(int)); + ctx.cum_output_seq_len += input_tokens.size(); + + for (int index = 0; index < (int)input_tokens.size(); ++index) { + if (input_tokens[index] >= ctx.input_vocab_size + && (index >= (int)input_masks.size() || input_masks[index])) { + std::ostringstream error_msg; + error_msg << "stream [" << stream->streamId() << "] token_id " << input_tokens[index] + << " exceed vocab_size " << ctx.input_vocab_size; + return absl::InvalidArgumentError(error_msg.str()); + } + } + + ctx.input_lengths[ctx.batch_idx] = input_tokens.size(); + ctx.prefix_lengths[prefill_batch_idx] = stream->prefixLength(); + ctx.lm_output_indexes[ctx.batch_idx] = ctx.cum_output_seq_len - 1; + ctx.lm_output_lengths[ctx.batch_idx] = 1; + + gatherMultimodalFeaturesForContextBatch(stream, ctx, gathered_mm_features); + + if (ctx.need_cal_position_id) { + auto context_pos_ids = stream->generateContextPositionIds(); + int reuse_offset = stream->reuseLength() * config_.position_id_len_factor; + memcpy(ctx.combo_position_ids + ctx.token_idx * config_.position_id_len_factor, + context_pos_ids.data_ptr() + reuse_offset, + (context_pos_ids.numel() - reuse_offset) * sizeof(int)); + } + + copyKvCacheBlocksToModelInput( + model_input, kv_cache, i, ctx.batch_idx, ctx.max_blocks_num, config_.kernel_blocks_per_kv_block); + + if (ctx.max_blocks_num && config_.role_type == RoleType::PREFILL && stream->hasCacheKeys()) { + std::memcpy(model_input.cache_keys.data_ptr() + + prefill_batch_idx * model_input.cache_keys.size(1), + stream->cacheKeys(i).data(), + stream->cacheKeys(i).size() * sizeof(int64_t)); + } + + *(model_input.request_id.data_ptr() + prefill_batch_idx) = stream->streamId(); + *(reinterpret_cast(model_input.request_pd_separation.data_ptr()) + prefill_batch_idx) = + stream->queryPdSep(); + + ctx.batch_idx += 1; + ctx.token_idx += input_tokens.size(); + } + + addCacheUpdateCopy(ctx, stream->streamCacheResource().getKVBlockUpdateMapping()); + stream->step(); + } + + if (config_.is_multimodal && !gathered_mm_features.empty()) { + model_input.multimodal_features = std::move(gathered_mm_features); + } + return absl::OkStatus(); +} + +absl::StatusOr NormalModelInputGatherer::gather(const StreamGroups& stream_groups) const { + RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + RTP_LLM_LOG_DEBUG("context_streams size = %d, decode_streams size = %d", + stream_groups.contextStreams().size(), + stream_groups.decodeStreams().size()); + auto model_input = allocateModelInputBuffers(stream_groups); + initializeKvCacheMetadata(model_input); + RETURN_IF_STATUS_ERROR(processDecodeStreams(model_input, stream_groups)); + RETURN_IF_STATUS_ERROR(processContextStreams(model_input, stream_groups)); + return model_input; +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h b/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h new file mode 100644 index 0000000000..f257376394 --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "rtp_llm/cpp/cache/CacheGroupType.h" +#include "rtp_llm/cpp/cache/Types.h" +#include "rtp_llm/cpp/config/ConfigModules.h" +#include "rtp_llm/cpp/core/OpData.h" +#include "rtp_llm/cpp/engine_base/stream/StreamGroups.h" +#include "rtp_llm/cpp/models/position_ids/PositionIdsGenerator.h" + +namespace rtp_llm { + +struct NormalModelInputGathererConfig { + size_t num_layers{}; + size_t vocab_size{}; + size_t input_vocab_size{}; + bool has_positional_encoding{}; + bool is_multimodal{}; + PositionIdsStyle mm_position_ids_style{}; + size_t position_id_len_factor{}; + RoleType role_type{}; + bool decode_entrance{}; + size_t block_stride_bytes{}; + size_t scale_stride_bytes{}; + size_t seq_size_per_block{}; + size_t kernel_seq_size_per_block{}; + size_t kernel_blocks_per_kv_block = 1; + size_t kv_cache_group_nums = 1; + std::vector layer_to_kv_cache_group_id; + std::vector kv_cache_group_types; + bool warm_up{}; + bool enable_detail_log{}; +}; + +class NormalModelInputGatherer { +public: + explicit NormalModelInputGatherer(const NormalModelInputGathererConfig& config); + + absl::StatusOr gather(const StreamGroups& stream_groups) const; + +private: + GptModelInputs allocateModelInputBuffers(const StreamGroups& stream_groups) const; + void initializeKvCacheMetadata(GptModelInputs& model_input) const; + absl::Status processDecodeStreams(GptModelInputs& model_input, const StreamGroups& stream_groups) const; + absl::Status processContextStreams(GptModelInputs& model_input, const StreamGroups& stream_groups) const; + + NormalModelInputGathererConfig config_; +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc new file mode 100644 index 0000000000..35655fba1d --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.cc @@ -0,0 +1,158 @@ +#include "rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h" +#include "rtp_llm/cpp/engine_base/stream/GenerateStream.h" +#include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/utils/TensorDebugUtils.h" +#include "rtp_llm/cpp/utils/ErrorCode.h" +#if USING_CUDA +#include "rtp_llm/cpp/cuda/ops/StandaloneOps.h" +#include "ATen/cuda/CUDAContext.h" +#endif + +namespace rtp_llm { + +absl::Status NormalOutputDispatcher::dispatch(const StreamGroups& stream_groups, + const MergedOutput& merge_outputs) const { + RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + const auto& sampler_output = merge_outputs.sampler_output; + const auto& new_all_token_ids = sampler_output.token_ids; + RTP_LLM_LOG_DEBUG("new_all_token_ids = [%s]", tensorDebugStringWithData(new_all_token_ids).c_str()); + const size_t total_batch_size_out = stream_groups.totalSamplerBatchSizeOut(); + RTP_LLM_CHECK(total_batch_size_out == (size_t)new_all_token_ids.size(0)); + int batch_idx_in = 0; + int batch_idx_out = 0; + int token_offset = 0; + bool return_all_probs = stream_groups.needReturnAllProbs(); + auto new_tokens_all = torch::empty({(int64_t)total_batch_size_out, 1}, torch::kInt32); + + for (auto& stream : stream_groups.allStreams()) { + auto cur_batch_size = stream->currentBatchSize(); + auto next_batch_size = stream->nextBatchSize(); + auto token_size = stream->currentExecuteTokenSize(); + + dispatchSingleStream( + stream, merge_outputs, batch_idx_in, batch_idx_out, token_offset, return_all_probs, new_tokens_all); + + batch_idx_in += cur_batch_size; + batch_idx_out += next_batch_size; + token_offset += token_size; + } + + RTP_LLM_LOG_DEBUG("dispatch done"); + return absl::OkStatus(); +} + +void NormalOutputDispatcher::dispatchSingleStream(GenerateStreamPtr stream, + const MergedOutput& merge_outputs, + int batch_idx_in, + int batch_idx_out, + int token_offset, + bool return_all_probs, + const torch::Tensor& new_tokens_all) const { + + const auto& model_output = merge_outputs.model_output; + const auto& sampler_output = merge_outputs.sampler_output; + const auto& new_all_token_ids = sampler_output.token_ids; + const size_t token_stride = new_all_token_ids.size(1); + + auto cur_batch_size = stream->currentBatchSize(); + auto next_batch_size = stream->nextBatchSize(); + auto token_size = stream->currentExecuteTokenSize(); + + auto batch_new_all_token_ids = new_all_token_ids.narrow(0, batch_idx_out, next_batch_size); + + bool has_beam_search = stream->currentNumBeams() > 1 || stream->nextNumBeams() > 1; + bool has_var_batch = stream->currentBatchSize() != stream->nextBatchSize(); + + // construct mapping from output batches to input batches + torch::Tensor src_batch_indices; + if (has_beam_search) { + // beam search + src_batch_indices = sampler_output.beam_index.narrow(0, batch_idx_out, next_batch_size); + } else if (has_var_batch) { + // from context stream to decode straem, there might be other cases in future + src_batch_indices = torch::zeros({(int64_t)next_batch_size}, torch::kInt32); + } + const auto get_src_idx = [&](int32_t dst_idx) { + return src_batch_indices.defined() ? src_batch_indices.data_ptr()[dst_idx] : dst_idx; + }; + + // construct update info + torch::Tensor batch_hidden_states; + if (stream->generateConfig()->return_hidden_states) { + batch_hidden_states = model_output.hidden_states.narrow(0, batch_idx_in, cur_batch_size); + } + + torch::Tensor batch_logits; + if (stream->returnLogits() || stream->calculateSoftmaxProbs() || has_beam_search) { + batch_logits = model_output.logits.narrow(0, batch_idx_in, cur_batch_size); + } + + torch::Tensor all_probs; + if (return_all_probs) { + all_probs = sampler_output.all_probs.narrow(0, batch_idx_out, next_batch_size); + }; + + torch::Tensor batch_cum_log_probs; + if (sampler_output.cum_log_probs.defined()) { + batch_cum_log_probs = sampler_output.cum_log_probs.narrow(0, batch_idx_out, next_batch_size); + } + + torch::Tensor loss; + if (stream->calculateLoss()) { + auto all_logits_tensor = model_output.all_logits.narrow(0, token_offset, token_size - 1); + auto tokens = stream->currentExecuteTokens(0); + auto label_tensor = + torch::from_blob(const_cast(tokens.data() + 1), {(int64_t)(tokens.size() - 1)}, torch::kInt32) + .to(torch::kCUDA); + auto labels_int64 = label_tensor.toType(torch::kInt64); + loss = torch::cross_entropy_loss(all_logits_tensor, labels_int64, torch::nullopt, at::Reduction::None) + .to(torch::kFloat32); + } + + torch::Tensor all_hidden_states; + if (stream->needReturnHiddenStates()) { + all_hidden_states = model_output.all_hidden_states.narrow(0, token_offset, token_size); + } + + auto new_tokens = new_tokens_all.narrow(0, batch_idx_out, next_batch_size); + for (size_t i = 0; i < next_batch_size; ++i) { + new_tokens.data_ptr()[i] = + new_all_token_ids.data_ptr()[(batch_idx_out + i) * token_stride + token_stride - 1]; + } + + torch::Tensor current_softmax_result; + if (stream->calculateSoftmaxProbs()) { + auto batch_softmax_input = batch_logits.to(torch::kFloat32).contiguous(); +#if USING_CUDA + cudaSoftmaxInplace(batch_softmax_input, at::cuda::getCurrentCUDAStream().stream()); +#else + batch_softmax_input = torch::softmax(batch_softmax_input, -1); +#endif + auto batch_softmax_tensor = batch_softmax_input.cpu(); + current_softmax_result = torch::empty({(int64_t)next_batch_size, 1}, torch::kFloat32); + for (int i = 0; i < next_batch_size; ++i) { + current_softmax_result[i][0] = batch_softmax_tensor[get_src_idx(i)][new_tokens.data_ptr()[i]]; + } + } + + for (int i = 0; i < cur_batch_size; ++i) { + if (sampler_output.success.defined() && !(sampler_output.success.data_ptr()[batch_idx_in + i])) { + stream->setStop(ErrorCode::UNKNOWN_ERROR, "sampler generate token id failed"); + } + } + + RTP_LLM_LOG_DEBUG("stream [%ld], new_tokens size = [%ld]", stream->streamId(), new_tokens.numel()); + + stream->update({has_beam_search ? batch_new_all_token_ids : new_tokens, + 1, + batch_hidden_states, + batch_logits, + current_softmax_result, + batch_cum_log_probs, + all_probs, + loss, + src_batch_indices, + all_hidden_states}); +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h new file mode 100644 index 0000000000..572755efd7 --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include "absl/status/status.h" +#include "rtp_llm/cpp/engine_base/stream/StreamGroups.h" +#include "rtp_llm/cpp/models/SampleInfos.h" + +namespace rtp_llm { + +class NormalOutputDispatcher { +public: + NormalOutputDispatcher() = default; + + absl::Status dispatch(const StreamGroups& stream_groups, const MergedOutput& merge_outputs) const; + +private: + void dispatchSingleStream(GenerateStreamPtr stream, + const MergedOutput& merge_outputs, + int batch_idx_in, + int batch_idx_out, + int token_offset, + bool return_all_probs, + const torch::Tensor& new_tokens_all) const; +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.cc b/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.cc new file mode 100644 index 0000000000..8aecf02b5a --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.cc @@ -0,0 +1,209 @@ +#include +#include +#include "torch/all.h" +#include "rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h" +#include "rtp_llm/cpp/utils/AssertUtils.h" +#include "rtp_llm/cpp/models/logits_processor/LogitsProcessorStates.h" +#include "rtp_llm/cpp/utils/TensorDebugUtils.h" + +namespace rtp_llm { + +absl::StatusOr NormalSamplerInputGatherer::gather(const StreamGroups& stream_groups, + const GptModelInputs& model_inputs, + const GptModelOutputs& model_output) const { + (void)model_inputs; + RTP_LLM_LOG_DEBUG(__PRETTY_FUNCTION__); + RTP_LLM_CHECK(!stream_groups.empty()); + auto all_streams = stream_groups.allStreams(); + auto total_batch_size_in = stream_groups.totalSamplerBatchSizeIn(); + auto total_batch_size_out = stream_groups.totalSamplerBatchSizeOut(); + bool return_all_probs = stream_groups.needReturnAllProbs(); + + SamplerInputs sampler_inputs = allocateSamplerInputs(stream_groups, total_batch_size_in, total_batch_size_out); + fillSamplerCommonInputs(sampler_inputs, all_streams); + + setLogitsProcessorInputs(sampler_inputs, all_streams); + + size_t total_decode_batch_size_in = 0; + int batch_idx = 0; + bool return_logits = false; + bool calculate_softmax_probs = false; + bool need_tiling = false; + for (auto& stream : all_streams) { + auto complete_token_ids = stream->completeTokenIds(); + auto complete_seq_len = complete_token_ids.size(1); + auto seq_len = stream->seqLength(); + auto current_batch_size = stream->currentBatchSize(); + auto sampler_batch_size = + stream->needTilingForSampling() ? stream->nextBatchSize() : stream->currentBatchSize(); + + for (int i = 0; i < sampler_batch_size; ++i) { + int cur_batch = std::min(i, current_batch_size - 1); + memcpy(sampler_inputs.token_ids.data_ptr() + ((batch_idx) * (sampler_inputs.step + 1)), + complete_token_ids.data_ptr() + cur_batch * complete_seq_len, + seq_len * sizeof(int)); + reinterpret_cast(sampler_inputs.finished_mask.data_ptr())[batch_idx] = stream->isDoneWithoutLock(i); + batch_idx += 1; + } + need_tiling |= stream->needTilingForSampling(); + if (!stream->isContextStream()) { + total_decode_batch_size_in += sampler_batch_size; + } + return_logits |= stream->returnLogits(); + calculate_softmax_probs |= stream->calculateSoftmaxProbs(); + RTP_LLM_LOG_DEBUG("stream [%ld], sampler inputs token ids = [%s]", + stream->streamId(), + tensorDebugStringWithData(sampler_inputs.token_ids).c_str()); + } + + auto vocab_size = (size_t)model_output.logits.size(1); + sampler_inputs.vocab_size = vocab_size; + if (return_all_probs) { + sampler_inputs.all_probs = torch::zeros({(int64_t)total_batch_size_in, (int64_t)vocab_size}, + torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + } + + // copy logits when needs tiling or returning logits + torch::Tensor logits_tensor; + if (need_tiling) { + logits_tensor = + torch::empty({(int64_t)total_batch_size_in, (int64_t)vocab_size}, model_output.logits.options()); + // copy decode batch logits + if (total_decode_batch_size_in > 0) { + logits_tensor.narrow(0, 0, total_decode_batch_size_in) + .copy_(model_output.logits.narrow(0, 0, total_decode_batch_size_in)); + } + // tile context batch logits + size_t input_offset = total_decode_batch_size_in, logits_offset = total_decode_batch_size_in; + for (auto& stream : stream_groups.contextStreams()) { + auto sampler_batch_size = + stream->needTilingForSampling() ? stream->nextBatchSize() : stream->currentBatchSize(); + for (int i = 0; i < sampler_batch_size; ++i) { + logits_tensor[input_offset].copy_(model_output.logits[logits_offset]); + input_offset += 1; + } + logits_offset += 1; + } + } else if (return_logits || calculate_softmax_probs) { + logits_tensor = model_output.logits.clone(); + } else { + logits_tensor = model_output.logits; + } + sampler_inputs.logits = logits_tensor; + + RTP_LLM_LOG_DEBUG("sampler inputs logits [%s]", + tensorDebugStringWithData(sampler_inputs.logits.cpu(), 10).c_str()); + + RTP_LLM_LOG_DEBUG("gatherSamplerInput done"); + return std::move(sampler_inputs); +} + +SamplerInputs NormalSamplerInputGatherer::allocateSamplerInputs(const StreamGroups& stream_groups, + size_t total_batch_size_in, + size_t total_batch_size_out, + size_t propose_step) const { + // TODO(xinfei.sxf) don't sample for chunk stream + SamplerInputs sampler_inputs; + sampler_inputs.step = stream_groups.maxSeqLen() + propose_step; + sampler_inputs.batch_size = total_batch_size_in; + sampler_inputs.batch_size_out = total_batch_size_out; + auto bs = (int64_t)total_batch_size_in; + sampler_inputs.sequence_lengths = torch::empty({bs}, torch::kInt32); + sampler_inputs.logits_processor_states_ptr.reset(); + sampler_inputs.input_lengths = torch::empty({bs}, torch::kInt32); + sampler_inputs.num_beams_in = torch::empty({bs}, torch::kLong); + sampler_inputs.num_beams_out = torch::empty({bs}, torch::kLong); + static const auto pinned_int = torch::TensorOptions(torch::kInt).pinned_memory(true); + static const auto pinned_i32 = torch::TensorOptions(torch::kInt32).pinned_memory(true); + static const auto pinned_f32 = torch::TensorOptions(torch::kFloat32).pinned_memory(true); + static const auto pinned_bool = torch::TensorOptions(torch::kBool).pinned_memory(true); + + sampler_inputs.top_k = torch::empty({bs}, pinned_int); + sampler_inputs.top_p = torch::empty({bs}, pinned_f32); + sampler_inputs.temperature = torch::empty({bs}, pinned_f32); + sampler_inputs.repetition_penalty = torch::empty({bs}, pinned_f32); + sampler_inputs.presence_penalty = torch::empty({bs}, pinned_f32); + sampler_inputs.frequency_penalty = torch::empty({bs}, pinned_f32); + sampler_inputs.no_repeat_ngram_size = torch::empty({bs}, pinned_i32); + sampler_inputs.do_sample = torch::empty({bs}, pinned_bool); + sampler_inputs.finished_mask = torch::empty({bs}, torch::kBool); + if (stream_groups.needReturnCumLogProbs()) { + sampler_inputs.cum_log_probs = torch::empty({(int64_t)total_batch_size_in}, torch::kFloat32); + } + sampler_inputs.token_ids = + torch::empty({(int64_t)total_batch_size_in, (int64_t)(sampler_inputs.step + 1)}, torch::kInt32); + sampler_inputs.generator.resize(total_batch_size_in); + return sampler_inputs; +} + +void NormalSamplerInputGatherer::fillSamplerCommonInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch, + size_t propose_step) const { + int* input_lengths = sampler_inputs.input_lengths.data_ptr(); + int* sequence_lengths = sampler_inputs.sequence_lengths.data_ptr(); + uint64_t* num_beams_in = reinterpret_cast(sampler_inputs.num_beams_in.data_ptr()); + uint64_t* num_beams_out = reinterpret_cast(sampler_inputs.num_beams_out.data_ptr()); + uint32_t* top_k = reinterpret_cast(sampler_inputs.top_k.data_ptr()); + float* top_p = sampler_inputs.top_p.data_ptr(); + float* temperature = sampler_inputs.temperature.data_ptr(); + float* repetition_penalty = sampler_inputs.repetition_penalty.data_ptr(); + float* presence_penalty = sampler_inputs.presence_penalty.data_ptr(); + float* frequency_penalty = sampler_inputs.frequency_penalty.data_ptr(); + int32_t* no_repeat_ngram_size = sampler_inputs.no_repeat_ngram_size.data_ptr(); + bool* do_sample = reinterpret_cast(sampler_inputs.do_sample.data_ptr()); + + int batch_idx = 0; + for (auto& stream : all_streams) { + int sampler_batch_size; + if (score_batch) { + sampler_batch_size = stream->scoreLen(); + } else if (stream->needTilingForSampling()) { + sampler_batch_size = stream->nextBatchSize(); + } else { + sampler_batch_size = stream->currentBatchSize(); + } + if (sampler_inputs.cum_log_probs.defined()) { + const auto& cum_log_probs = stream->cumLogProbs(); + memcpy(sampler_inputs.cum_log_probs.data_ptr() + batch_idx, + cum_log_probs.data_ptr(), + cum_log_probs.numel() * sizeof(float)); + } + for (int i = 0; i < sampler_batch_size; ++i) { + input_lengths[batch_idx] = stream->inputLength(); + sequence_lengths[batch_idx] = stream->seqLength() + propose_step; + num_beams_in[batch_idx] = stream->currentNumBeams(); + num_beams_out[batch_idx] = stream->nextNumBeams(); + top_k[batch_idx] = stream->generateConfig()->top_k; + top_p[batch_idx] = stream->generateConfig()->top_p; + temperature[batch_idx] = stream->generateConfig()->temperature; + repetition_penalty[batch_idx] = stream->generateConfig()->repetition_penalty; + presence_penalty[batch_idx] = stream->generateConfig()->presence_penalty; + frequency_penalty[batch_idx] = stream->generateConfig()->frequency_penalty; + do_sample[batch_idx] = stream->generateConfig()->do_sample; + if (!do_sample[batch_idx]) { + top_k[batch_idx] = 1; + top_p[batch_idx] = 1; + temperature[batch_idx] = 1; + } + no_repeat_ngram_size[batch_idx] = stream->generateConfig()->no_repeat_ngram_size.value_or(0); + sampler_inputs.generator[batch_idx] = stream->getGenerator(); + batch_idx += 1; + } + } +} + +void NormalSamplerInputGatherer::setLogitsProcessorInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch) const { + LogitsProcessorStatesPtr state_ptr = std::make_shared(); + std::for_each(all_streams.begin(), all_streams.end(), [&state_ptr, idx = 0](auto& stream) mutable { + for (const auto& processor : stream->getAllLogitsProcessorPtr()) { + state_ptr->insert(processor, idx, idx + stream->currentBatchSize()); + } + idx += stream->currentBatchSize(); + }); + sampler_inputs.logits_processor_states_ptr = state_ptr; +} + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h b/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h new file mode 100644 index 0000000000..380afc7e72 --- /dev/null +++ b/rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include +#include "absl/status/statusor.h" +#include "rtp_llm/cpp/engine_base/stream/StreamGroups.h" +#include "rtp_llm/cpp/models/SampleInfos.h" + +namespace rtp_llm { + +class NormalSamplerInputGatherer { +public: + NormalSamplerInputGatherer() = default; + + absl::StatusOr gather(const StreamGroups& stream_groups, + const GptModelInputs& model_inputs, + const GptModelOutputs& model_output) const; + + SamplerInputs allocateSamplerInputs(const StreamGroups& stream_groups, + size_t total_batch_size_in, + size_t total_batch_size_out, + size_t propose_step = 0) const; + + void fillSamplerCommonInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch = false, + size_t propose_step = 0) const; + + void setLogitsProcessorInputs(SamplerInputs& sampler_inputs, + std::list& all_streams, + bool score_batch = false) const; +}; + +} // namespace rtp_llm diff --git a/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc b/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc index 15e4ddb937..7ae3ce0c6e 100644 --- a/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc +++ b/rtp_llm/cpp/normal_engine/speculative/MtpBatchStreamProcessor.cc @@ -64,6 +64,7 @@ MtpBatchStreamProcessor::gatherDecodeModelInput(const StreamGroups& stream_group absl::StatusOr MtpBatchStreamProcessor::gatherSpecSamplerInput( const StreamGroups& stream_groups, const GptModelInputs& model_inputs, const GptModelOutputs& model_output) const { + (void)model_inputs; RTP_LLM_CHECK(!stream_groups.empty()); auto all_streams = stream_groups.allStreams(); bool return_all_probs = stream_groups.needReturnAllProbs(); @@ -75,9 +76,9 @@ absl::StatusOr MtpBatchStreamProcessor::gatherSpecSamplerInput( size_t score_len = propose_step_ + 1; size_t total_batch_size = stream_groups.size() * score_len; - SamplerInputs sampler_inputs = allocateSamplerInputs( - stream_groups, total_batch_size, total_batch_size, model_inputs.sequence_lengths, propose_step_); - setCommonSamplerInputs(sampler_inputs, all_streams, true, propose_step_); + SamplerInputs sampler_inputs = + allocateSamplerInputs(stream_groups, total_batch_size, total_batch_size, propose_step_); + fillSamplerCommonInputs(sampler_inputs, all_streams, true, propose_step_); int batch_idx = 0; for (auto& stream : all_streams) { diff --git a/rtp_llm/cpp/normal_engine/speculative/test/MtpBatchStreamProcessorTest.cc b/rtp_llm/cpp/normal_engine/speculative/test/MtpBatchStreamProcessorTest.cc index a04c1238ac..52a7467f4d 100644 --- a/rtp_llm/cpp/normal_engine/speculative/test/MtpBatchStreamProcessorTest.cc +++ b/rtp_llm/cpp/normal_engine/speculative/test/MtpBatchStreamProcessorTest.cc @@ -84,6 +84,7 @@ TEST_F(MtpBatchStreamProcessorTest, testPrefillDispatch) { PDSepConfig pd_sep_config; ProfilingDebugLoggingConfig profiling_debug_logging_config; CacheConfig cache_config; + cache_config.group_types = {CacheGroupType::FULL}; model_config.max_seq_len = 2048; model_config.vocab_size = 4; @@ -101,7 +102,6 @@ TEST_F(MtpBatchStreamProcessorTest, testPrefillDispatch) { MtpBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); StreamGroups stream_groups(streams); @@ -167,9 +167,9 @@ TEST_F(MtpBatchStreamProcessorTest, testDispatchDecodeStream) { draft_prefill_output.sampler_output.all_probs = torch::tensor({0.2f, 0.1f, 0.3f, 0.5f, 0.3f, 0.1f, 0.4f, 0.2f}, torch::kFloat32).reshape({2, 4}); + cache_config.group_types = {CacheGroupType::FULL}; MtpBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto status = processor.dispatchDecode(stream_groups, spec_decode_output, std::move(draft_prefill_output)); EXPECT_TRUE(status.ok()); @@ -216,9 +216,9 @@ TEST_F(MtpBatchStreamProcessorTest, testGatherDecodeModelInput) { auto stream_groups = StreamGroups({stream1, stream2}); - auto processor = MtpBatchStreamProcessor( + cache_config.group_types = {CacheGroupType::FULL}; + auto processor = MtpBatchStreamProcessor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto model_input = processor.gatherDecodeModelInput(stream_groups); EXPECT_TRUE(model_input.ok()); @@ -293,9 +293,9 @@ TEST_F(MtpBatchStreamProcessorTest, testPrepareOneStepSpecDecodeModelInput) { auto stream_groups = StreamGroups({stream1, stream2}); - auto processor = MtpBatchStreamProcessor( + cache_config.group_types = {CacheGroupType::FULL}; + auto processor = MtpBatchStreamProcessor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto model_input_status = processor.gatherDecodeModelInput(stream_groups); EXPECT_TRUE(model_input_status.ok()); @@ -391,9 +391,9 @@ TEST_F(MtpBatchStreamProcessorTest, testprepareDecodeDraftModelInput) { auto stream_groups = StreamGroups({stream1, stream2}); - auto processor = MtpBatchStreamProcessor( + cache_config.group_types = {CacheGroupType::FULL}; + auto processor = MtpBatchStreamProcessor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto model_input_status = processor.gatherDecodeModelInput(stream_groups); EXPECT_TRUE(model_input_status.ok()); @@ -446,9 +446,9 @@ TEST_F(MtpBatchStreamProcessorTest, testUpdatePrefillPostDraftModelInput) { auto stream_groups = StreamGroups({stream1, stream2}); - auto processor = MtpBatchStreamProcessor( + cache_config.group_types = {CacheGroupType::FULL}; + auto processor = MtpBatchStreamProcessor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto model_input_status = processor.gatherModelInput(stream_groups); EXPECT_TRUE(model_input_status.ok()); @@ -504,9 +504,9 @@ TEST_F(MtpBatchStreamProcessorTest, testUpdateDecodePostDraftModelInput) { auto stream_groups = StreamGroups({stream1, stream2}); - auto processor = MtpBatchStreamProcessor( + cache_config.group_types = {CacheGroupType::FULL}; + auto processor = MtpBatchStreamProcessor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, sp_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); auto model_input_status = processor.gatherModelInput(stream_groups); EXPECT_TRUE(model_input_status.ok()); diff --git a/rtp_llm/cpp/normal_engine/test/NormalBatchStreamProcessorTest.cc b/rtp_llm/cpp/normal_engine/test/NormalBatchStreamProcessorTest.cc index f3e2a14ec5..fc442f1741 100644 --- a/rtp_llm/cpp/normal_engine/test/NormalBatchStreamProcessorTest.cc +++ b/rtp_llm/cpp/normal_engine/test/NormalBatchStreamProcessorTest.cc @@ -36,11 +36,11 @@ TEST_F(NormalBatchStreamProcessorTest, testSimpleAssemble) { PDSepConfig pd_sep_config; ProfilingDebugLoggingConfig profiling_debug_logging_config; CacheConfig cache_config; + cache_config.group_types = {CacheGroupType::FULL}; RuntimeConfig runtime_config; NormalBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); std::shared_ptr query1 = make_shared(); query1->input_ids = hostIntBuffer({1, 2}); @@ -124,7 +124,6 @@ TEST_F(NormalBatchStreamProcessorTest, testSimpleAssemble) { model_config.mm_model_config = mm_model_config; NormalBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); StreamGroups stream_groups(streams); auto merge_input_status = processor.gatherModelInput(stream_groups); @@ -163,9 +162,9 @@ TEST_F(NormalBatchStreamProcessorTest, testSoftmaxProbs) { for (const auto& stream : streams) { stream->setRunning(); } + cache_config.group_types = {CacheGroupType::FULL}; NormalBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); StreamGroups stream_groups(streams); auto merge_input_status = processor.gatherModelInput(stream_groups); @@ -242,9 +241,9 @@ TEST_F(NormalBatchStreamProcessorTest, testLoss) { for (const auto& stream : streams) { stream->setRunning(); } + cache_config.group_types = {CacheGroupType::FULL}; NormalBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); StreamGroups stream_groups(streams); auto merge_input_status = processor.gatherModelInput(stream_groups); @@ -289,10 +288,10 @@ TEST_F(NormalBatchStreamProcessorTest, testMultimodalGatherBatch) { PDSepConfig pd_sep_config; ProfilingDebugLoggingConfig profiling_debug_logging_config; CacheConfig cache_config; - RuntimeConfig runtime_config; - NormalBatchStreamProcessor processor( + cache_config.group_types = {CacheGroupType::FULL}; + RuntimeConfig runtime_config; + NormalBatchStreamProcessor processor( model_config, pd_sep_config, profiling_debug_logging_config, cache_config, false); - processor.setKVCacheGroupTypes({CacheGroupType::FULL}); std::shared_ptr query1 = make_shared(); query1->input_ids = hostIntBuffer({1, -1, -1, -1, 2});