11#pragma once
22
33#include < memory>
4+ #include " absl/status/status.h"
5+ #include " absl/status/statusor.h"
6+ #include " rtp_llm/cpp/cache/CacheConfig.h"
7+ #include " rtp_llm/cpp/config/ConfigModules.h"
48#include " rtp_llm/cpp/devices/DeviceBase.h"
59#include " rtp_llm/cpp/cache/CacheGroupType.h"
6- #include " rtp_llm/cpp/config/ConfigModules.h"
7- #include " rtp_llm/cpp/models/SampleInfos.h"
810#include " rtp_llm/cpp/engine_base/stream/StreamGroups.h"
9- #include " absl/status/statusor.h"
10- #include " absl/status/status.h"
11+ #include " rtp_llm/cpp/models/SampleInfos.h"
12+ #include " rtp_llm/cpp/normal_engine/NormalModelInputGatherer.h"
13+ #include " rtp_llm/cpp/normal_engine/NormalOutputDispatcher.h"
14+ #include " rtp_llm/cpp/normal_engine/NormalSamplerInputGatherer.h"
1115
1216namespace rtp_llm {
1317
@@ -17,84 +21,29 @@ class NormalBatchStreamProcessor {
1721 const PDSepConfig& pd_sep_config,
1822 const ProfilingDebugLoggingConfig& profiling_debug_logging_config,
1923 const CacheConfig& cache_config,
20- bool warm_up):
21- num_layers_ (model_config.num_layers),
22- vocab_size_ (model_config.vocab_size),
23- input_vocab_size_ (model_config.input_vocab_size),
24- use_int8_kv_cache_ (model_config.attn_config.kv_cache_dtype == rtp_llm::KvCacheDataType::INT8),
25- has_positional_encoding_ (model_config.has_positional_encoding),
26- is_multimodal_ (model_config.mm_model_config.is_multimodal),
27- mm_position_ids_style_ ((PositionIdsStyle)model_config.mm_model_config.mm_position_ids_style),
28- position_id_len_factor_ (model_config.attn_config.rope_config.index_factor),
29- role_type_ (pd_sep_config.role_type),
30- decode_entrance_ (pd_sep_config.decode_entrance),
31- block_stride_bytes_ (cache_config.kv_block_stride_bytes),
32- scale_stride_bytes_ (cache_config.kv_scale_stride_bytes),
33- seq_size_per_block_ (cache_config.seq_size_per_block),
34- kernel_seq_size_per_block_ (cache_config.kernel_seq_size_per_block),
35- kernel_blocks_per_kv_block_ (cache_config.kernelBlocksPerKvBlock()),
36- kv_cache_group_nums_ (cache_config.groupNums()),
37- layer_to_kv_cache_group_id_ (cache_config.layer_to_group_id),
38- kv_cache_group_types_ (cache_config.group_types),
39- warm_up_ (warm_up),
40- enable_detail_log_ (profiling_debug_logging_config.enable_detail_log),
41- device_ (rtp_llm::DeviceFactory::getDefaultDevice()) {}
24+ bool warm_up);
4225
4326 virtual absl::Status dispatch (const StreamGroups& stream_groups, const MergedOutput& merge_outputs) const ;
4427 virtual absl::StatusOr<GptModelInputs> gatherModelInput (const StreamGroups& stream_groups) const ;
4528 virtual absl::StatusOr<SamplerInputs> gatherSamplerInput (const StreamGroups& stream_groups,
46- const GptModelInputs& model_inputs,
4729 const GptModelOutputs& model_output) const ;
4830
4931protected:
50- SamplerInputs allocateSamplerInputs (const StreamGroups& stream_groups,
51- size_t total_batch_size_in,
52- size_t total_batch_size_out,
53- const rtp_llm::BufferPtr& sequence_length,
54- size_t propose_step = 0 ) const ;
55- void setCommonSamplerInputs (SamplerInputs& sampler_inputs,
56- std::list<GenerateStreamPtr>& all_streams,
57- bool score_batch = false ,
58- size_t propose_step = 0 ) const ;
59- void setLogitsProcessorInputs (SamplerInputs& sampler_inputs,
60- std::list<GenerateStreamPtr>& all_streams,
61- bool score_batch = false ) const ;
62-
63- void dispatchSingleStream (GenerateStreamPtr stream,
64- const MergedOutput& merge_outputs,
65- int batch_idx_in,
66- int batch_idx_out,
67- int token_offset,
68- bool return_all_probs,
69- const BufferPtr& new_tokens_all) const ;
70-
71- void setKVCacheGroupTypes (std::vector<CacheGroupType> kv_cache_group_types) {
72- kv_cache_group_types_ = kv_cache_group_types;
73- }
74-
75- protected:
76- size_t num_layers_;
77- size_t vocab_size_;
78- size_t input_vocab_size_;
79- bool use_int8_kv_cache_;
80- bool has_positional_encoding_;
81- bool is_multimodal_;
82- PositionIdsStyle mm_position_ids_style_;
83- size_t position_id_len_factor_;
84- RoleType role_type_;
85- bool decode_entrance_;
86- size_t block_stride_bytes_;
87- size_t scale_stride_bytes_;
88- size_t seq_size_per_block_;
89- size_t kernel_seq_size_per_block_;
90- size_t kernel_blocks_per_kv_block_ = 1 ;
91- size_t kv_cache_group_nums_ = 1 ;
92- mutable std::vector<int32_t > layer_to_kv_cache_group_id_;
93- std::vector<CacheGroupType> kv_cache_group_types_;
94- bool warm_up_;
95- bool enable_detail_log_;
32+ SamplerInputs allocateSamplerInputs (const StreamGroups& stream_groups,
33+ size_t total_batch_size_in,
34+ size_t total_batch_size_out,
35+ size_t propose_step) const ;
36+ void fillSamplerCommonInputs (SamplerInputs& sampler_inputs,
37+ std::list<GenerateStreamPtr>& all_streams,
38+ bool score_batch = false ,
39+ size_t propose_step = 0 ) const ;
9640
9741 rtp_llm::DeviceBase* device_;
42+ size_t vocab_size_;
43+
44+ std::unique_ptr<NormalModelInputGatherer> model_input_gatherer_;
45+ std::unique_ptr<NormalSamplerInputGatherer> sampler_input_gatherer_;
46+ std::unique_ptr<NormalOutputDispatcher> output_dispatcher_;
9847};
9948
10049} // namespace rtp_llm
0 commit comments