|
| 1 | +#include "chunk_prefill_compiler.hpp" |
| 2 | +#include "../../global_state/global_state.hpp" |
| 3 | +#include "infinicore/context/context.hpp" |
| 4 | + |
| 5 | + |
| 6 | +namespace infinilm::engine { |
| 7 | + |
| 8 | +ChunkPrefillCompiler::ChunkPrefillCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) |
| 9 | + : GraphCompiler(model, barrier) { |
| 10 | + // Enumerate chunk sizes for chunk-prefill |
| 11 | + for (size_t cs : {256}) { |
| 12 | + chunk_sizes_.push_back(cs); |
| 13 | + } |
| 14 | + // Enumerate batch sizes for prefill (typically smaller than decode) |
| 15 | + for (size_t b = 1; b < 32; b++) { |
| 16 | + prefill_batch_sizes_.push_back(b); |
| 17 | + } |
| 18 | + for (size_t b = 32; b < 64; b += 8) { |
| 19 | + prefill_batch_sizes_.push_back(b); |
| 20 | + } |
| 21 | + for (size_t b = 64; b < 128; b += 16) { |
| 22 | + prefill_batch_sizes_.push_back(b); |
| 23 | + } |
| 24 | + for (size_t b = 128; b < 256; b += 32) { |
| 25 | + prefill_batch_sizes_.push_back(b); |
| 26 | + } |
| 27 | + for (size_t b = 256; b <= 512; b += 64) { |
| 28 | + prefill_batch_sizes_.push_back(b); |
| 29 | + } |
| 30 | +} |
| 31 | + |
| 32 | +void ChunkPrefillCompiler::compile() { |
| 33 | + if (model_->get_cache_config() != nullptr && |
| 34 | + dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) { |
| 35 | + |
| 36 | + const auto *paged_config = |
| 37 | + dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config()); |
| 38 | + size_t nblocks = paged_config->num_blocks(); |
| 39 | + |
| 40 | + compiled_map_prefill_.clear(); |
| 41 | + |
| 42 | + // Max total tokens to avoid OOM during graph recording |
| 43 | + constexpr size_t MAX_TOTAL_TOKENS = 4096; |
| 44 | + |
| 45 | + // Pre-allocate a shared block_tables_holder for the largest (batch_size) we'll use |
| 46 | + size_t max_batch = *std::max_element(prefill_batch_sizes_.begin(), prefill_batch_sizes_.end()); |
| 47 | + size_t block_per_req = nblocks / max_batch; |
| 48 | + block_tables_holder_ = infinicore::Tensor::zeros( |
| 49 | + {nblocks}, infinicore::DataType::I32, infinicore::context::getDevice()); |
| 50 | + |
| 51 | + for (size_t b : prefill_batch_sizes_) { |
| 52 | + for (size_t cs : chunk_sizes_) { |
| 53 | + size_t total_tokens = b * cs; |
| 54 | + if (total_tokens > MAX_TOTAL_TOKENS) { |
| 55 | + continue; |
| 56 | + } |
| 57 | + |
| 58 | + size_t bpr = nblocks / b; // block_per_req for this batch size |
| 59 | + |
| 60 | + InfinilmModel::Input input; |
| 61 | + |
| 62 | + // input_ids: [1, total_tokens] — all tokens for this batch packed together |
| 63 | + input.input_ids = infinicore::Tensor::zeros( |
| 64 | + {1, total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); |
| 65 | + |
| 66 | + // position_ids: [total_tokens] |
| 67 | + input.position_ids = infinicore::Tensor::zeros( |
| 68 | + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); |
| 69 | + |
| 70 | + // total_sequence_lengths: [b], set to cs (first-chunk scenario) |
| 71 | + input.total_sequence_lengths = infinicore::Tensor::empty( |
| 72 | + {b}, infinicore::DataType::I32, infinicore::context::getDevice()); |
| 73 | + { |
| 74 | + std::vector<int32_t> tsl(b, static_cast<int32_t>(cs)); |
| 75 | + infinicore::context::memcpyH2D( |
| 76 | + input.total_sequence_lengths.value()->data(), |
| 77 | + tsl.data(), b * sizeof(int32_t), false); |
| 78 | + } |
| 79 | + |
| 80 | + // input_offsets: [b+1], stride = cs |
| 81 | + input.input_offsets = infinicore::Tensor::empty( |
| 82 | + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); |
| 83 | + { |
| 84 | + std::vector<int32_t> offsets(b + 1); |
| 85 | + for (size_t i = 0; i <= b; i++) { |
| 86 | + offsets[i] = static_cast<int32_t>(i * cs); |
| 87 | + } |
| 88 | + infinicore::context::memcpyH2D( |
| 89 | + input.input_offsets.value()->data(), |
| 90 | + offsets.data(), (b + 1) * sizeof(int32_t), false); |
| 91 | + } |
| 92 | + |
| 93 | + // cu_seqlens: [b+1], same layout as input_offsets for prefill |
| 94 | + input.cu_seqlens = infinicore::Tensor::empty( |
| 95 | + {b + 1}, infinicore::DataType::I32, infinicore::context::getDevice()); |
| 96 | + { |
| 97 | + std::vector<int32_t> cu(b + 1); |
| 98 | + for (size_t i = 0; i <= b; i++) { |
| 99 | + cu[i] = static_cast<int32_t>(i * cs); |
| 100 | + } |
| 101 | + infinicore::context::memcpyH2D( |
| 102 | + input.cu_seqlens.value()->data(), |
| 103 | + cu.data(), (b + 1) * sizeof(int32_t), false); |
| 104 | + } |
| 105 | + |
| 106 | + // block_tables: view into the shared holder [b, bpr] |
| 107 | + input.block_tables = block_tables_holder_->as_strided( |
| 108 | + {b, bpr}, {(ptrdiff_t)bpr, 1}); |
| 109 | + |
| 110 | + // slot_mapping: [total_tokens] |
| 111 | + input.slot_mapping = infinicore::Tensor::zeros( |
| 112 | + {total_tokens}, infinicore::DataType::I64, infinicore::context::getDevice()); |
| 113 | + |
| 114 | + // Attention reads attn_metadata from thread-local forward context. |
| 115 | + infinilm::global_state::get_forward_context().attn_metadata = { |
| 116 | + input.past_sequence_lengths, |
| 117 | + input.total_sequence_lengths, |
| 118 | + input.input_offsets, |
| 119 | + input.cu_seqlens, |
| 120 | + input.block_tables, |
| 121 | + input.slot_mapping, |
| 122 | + }; |
| 123 | + |
| 124 | + barrier_->wait(); |
| 125 | + infinicore::context::startGraphRecording(); |
| 126 | + auto output = model_->forward(input); |
| 127 | + auto graph = infinicore::context::stopGraphRecording(); |
| 128 | + barrier_->wait(); |
| 129 | + |
| 130 | + auto shared_output = std::shared_ptr<InfinilmModel::Output>( |
| 131 | + new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)}); |
| 132 | + |
| 133 | + compiled_map_prefill_[std::make_tuple(b, cs)] = |
| 134 | + CompiledResult{std::move(input), std::make_tuple(graph, shared_output)}; |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +ChunkPrefillCompiler::Compiled ChunkPrefillCompiler::get_compiled(const InfinilmModel::Input &input) { |
| 141 | + if (model_->get_cache_config() == nullptr || |
| 142 | + !dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) { |
| 143 | + return {nullptr, nullptr}; |
| 144 | + } |
| 145 | + |
| 146 | + if (!input.block_tables.has_value() || !input.input_ids.has_value()) { |
| 147 | + return {nullptr, nullptr}; |
| 148 | + } |
| 149 | + |
| 150 | + size_t batch_size = input.block_tables.value()->size(0); |
| 151 | + size_t block_per_req = input.block_tables.value()->size(1); |
| 152 | + size_t total_tokens = input.input_ids.value()->size(1); |
| 153 | + |
| 154 | + // Prefill: total_tokens is a multiple of batch_size, and chunk_size > 1 |
| 155 | + if (total_tokens == 0 || total_tokens % batch_size != 0) { |
| 156 | + return {nullptr, nullptr}; |
| 157 | + } |
| 158 | + size_t chunk_size = total_tokens / batch_size; |
| 159 | + if (chunk_size <= 1) { |
| 160 | + // Single-token case belongs to decode |
| 161 | + return {nullptr, nullptr}; |
| 162 | + } |
| 163 | + |
| 164 | + auto result = compiled_map_prefill_.find(std::make_tuple(batch_size, chunk_size)); |
| 165 | + if (result == compiled_map_prefill_.end()) { |
| 166 | + return {nullptr, nullptr}; |
| 167 | + } |
| 168 | + |
| 169 | + auto &graph_input = result->second.input; |
| 170 | + |
| 171 | + graph_input.input_ids.value()->copy_from(input.input_ids.value()); |
| 172 | + graph_input.position_ids.value()->copy_from(input.position_ids.value()); |
| 173 | + graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); |
| 174 | + graph_input.input_offsets.value()->copy_from(input.input_offsets.value()); |
| 175 | + graph_input.cu_seqlens.value()->copy_from(input.cu_seqlens.value()); |
| 176 | + graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value()); |
| 177 | + graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value()); |
| 178 | + |
| 179 | + auto graph = std::get<0>(result->second.compiled); |
| 180 | + auto shared_output = std::shared_ptr<InfinilmModel::Output>( |
| 181 | + new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); |
| 182 | + |
| 183 | + return std::make_tuple(graph, shared_output); |
| 184 | +} |
| 185 | + |
| 186 | +} // namespace infinilm::engine |
0 commit comments