Skip to content

Commit f13abd8

Browse files
committed
refactor batch stream processor
1 parent 44027b7 commit f13abd8

13 files changed

Lines changed: 979 additions & 729 deletions

rtp_llm/cpp/normal_engine/NormalBatchStreamProcessor.cc

Lines changed: 54 additions & 631 deletions
Large diffs are not rendered by default.
Lines changed: 22 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
#pragma once
22

33
#include <memory>
4+
#include "absl/status/status.h"
5+
#include "absl/status/statusor.h"
6+
#include "rtp_llm/cpp/cache/CacheConfig.h"
7+
#include "rtp_llm/cpp/config/ConfigModules.h"
48
#include "rtp_llm/cpp/devices/DeviceBase.h"
59
#include "rtp_llm/cpp/cache/CacheGroupType.h"
6-
#include "rtp_llm/cpp/config/ConfigModules.h"
7-
#include "rtp_llm/cpp/models/SampleInfos.h"
810
#include "rtp_llm/cpp/engine_base/stream/StreamGroups.h"
9-
#include "absl/status/statusor.h"
10-
#include "absl/status/status.h"
11+
#include "rtp_llm/cpp/models/SampleInfos.h"
12+
#include "rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h"
13+
#include "rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h"
14+
#include "rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h"
1115

1216
namespace rtp_llm {
1317

@@ -17,84 +21,29 @@ class NormalBatchStreamProcessor {
1721
const PDSepConfig& pd_sep_config,
1822
const ProfilingDebugLoggingConfig& profiling_debug_logging_config,
1923
const CacheConfig& cache_config,
20-
bool warm_up):
21-
num_layers_(model_config.num_layers),
22-
vocab_size_(model_config.vocab_size),
23-
input_vocab_size_(model_config.input_vocab_size),
24-
use_int8_kv_cache_(model_config.attn_config.kv_cache_dtype == rtp_llm::KvCacheDataType::INT8),
25-
has_positional_encoding_(model_config.has_positional_encoding),
26-
is_multimodal_(model_config.mm_model_config.is_multimodal),
27-
mm_position_ids_style_((PositionIdsStyle)model_config.mm_model_config.mm_position_ids_style),
28-
position_id_len_factor_(model_config.attn_config.rope_config.index_factor),
29-
role_type_(pd_sep_config.role_type),
30-
decode_entrance_(pd_sep_config.decode_entrance),
31-
block_stride_bytes_(cache_config.kv_block_stride_bytes),
32-
scale_stride_bytes_(cache_config.kv_scale_stride_bytes),
33-
seq_size_per_block_(cache_config.seq_size_per_block),
34-
kernel_seq_size_per_block_(cache_config.kernel_seq_size_per_block),
35-
kernel_blocks_per_kv_block_(cache_config.kernelBlocksPerKvBlock()),
36-
kv_cache_group_nums_(cache_config.groupNums()),
37-
layer_to_kv_cache_group_id_(cache_config.layer_to_group_id),
38-
kv_cache_group_types_(cache_config.group_types),
39-
warm_up_(warm_up),
40-
enable_detail_log_(profiling_debug_logging_config.enable_detail_log),
41-
device_(rtp_llm::DeviceFactory::getDefaultDevice()) {}
24+
bool warm_up);
4225

4326
virtual absl::Status dispatch(const StreamGroups& stream_groups, const MergedOutput& merge_outputs) const;
4427
virtual absl::StatusOr<GptModelInputs> gatherModelInput(const StreamGroups& stream_groups) const;
4528
virtual absl::StatusOr<SamplerInputs> gatherSamplerInput(const StreamGroups& stream_groups,
46-
const GptModelInputs& model_inputs,
4729
const GptModelOutputs& model_output) const;
4830

4931
protected:
50-
SamplerInputs allocateSamplerInputs(const StreamGroups& stream_groups,
51-
size_t total_batch_size_in,
52-
size_t total_batch_size_out,
53-
const rtp_llm::BufferPtr& sequence_length,
54-
size_t propose_step = 0) const;
55-
void setCommonSamplerInputs(SamplerInputs& sampler_inputs,
56-
std::list<GenerateStreamPtr>& all_streams,
57-
bool score_batch = false,
58-
size_t propose_step = 0) const;
59-
void setLogitsProcessorInputs(SamplerInputs& sampler_inputs,
60-
std::list<GenerateStreamPtr>& all_streams,
61-
bool score_batch = false) const;
62-
63-
void dispatchSingleStream(GenerateStreamPtr stream,
64-
const MergedOutput& merge_outputs,
65-
int batch_idx_in,
66-
int batch_idx_out,
67-
int token_offset,
68-
bool return_all_probs,
69-
const BufferPtr& new_tokens_all) const;
70-
71-
void setKVCacheGroupTypes(std::vector<CacheGroupType> kv_cache_group_types) {
72-
kv_cache_group_types_ = kv_cache_group_types;
73-
}
74-
75-
protected:
76-
size_t num_layers_;
77-
size_t vocab_size_;
78-
size_t input_vocab_size_;
79-
bool use_int8_kv_cache_;
80-
bool has_positional_encoding_;
81-
bool is_multimodal_;
82-
PositionIdsStyle mm_position_ids_style_;
83-
size_t position_id_len_factor_;
84-
RoleType role_type_;
85-
bool decode_entrance_;
86-
size_t block_stride_bytes_;
87-
size_t scale_stride_bytes_;
88-
size_t seq_size_per_block_;
89-
size_t kernel_seq_size_per_block_;
90-
size_t kernel_blocks_per_kv_block_ = 1;
91-
size_t kv_cache_group_nums_ = 1;
92-
mutable std::vector<int32_t> layer_to_kv_cache_group_id_;
93-
std::vector<CacheGroupType> kv_cache_group_types_;
94-
bool warm_up_;
95-
bool enable_detail_log_;
32+
SamplerInputs allocateSamplerInputs(const StreamGroups& stream_groups,
33+
size_t total_batch_size_in,
34+
size_t total_batch_size_out,
35+
size_t propose_step) const;
36+
void fillSamplerCommonInputs(SamplerInputs& sampler_inputs,
37+
std::list<GenerateStreamPtr>& all_streams,
38+
bool score_batch = false,
39+
size_t propose_step = 0) const;
9640

9741
rtp_llm::DeviceBase* device_;
42+
size_t vocab_size_;
43+
44+
std::unique_ptr<NormalModelInputGatherer> model_input_gatherer_;
45+
std::unique_ptr<NormalSamplerInputGatherer> sampler_input_gatherer_;
46+
std::unique_ptr<NormalOutputDispatcher> output_dispatcher_;
9847
};
9948

10049
} // namespace rtp_llm

rtp_llm/cpp/normal_engine/NormalExecutor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,7 @@ absl::Status NormalExecutor::process(const std::list<GenerateStreamPtr>& streams
170170
{
171171
RTP_LLM_PROFILE_SCOPE("executor.sampler_forward");
172172
int64_t start_time_us = autil::TimeUtility::currentTimeInMicroSeconds();
173-
CHECK_AND_RETURN_REF(sampler_input,
174-
batch_stream_processor_->gatherSamplerInput(stream_groups, model_input, model_output));
173+
CHECK_AND_RETURN_REF(sampler_input, batch_stream_processor_->gatherSamplerInput(stream_groups, model_output));
175174
sampler_output = std::move(sampler_->forward(sampler_input));
176175
RTP_LLM_LOG_DEBUG("sampler forward done");
177176
executor_collector.sample_input_us = autil::TimeUtility::currentTimeInMicroSeconds() - start_time_us;

0 commit comments

Comments
 (0)