@@ -16,7 +16,6 @@ limitations under the License.
1616#include " block_manager_pool.h"
1717
1818#include < algorithm>
19- #include < limits>
2019
2120#include " block_manager_impl.h"
2221#include " common/global_flags.h"
@@ -340,10 +339,30 @@ void BlockManagerPool::allocate_shared(Sequence* sequence) {
340339}
341340
342341void BlockManagerPool::cache (Sequence* sequence) {
342+ cache (sequence, sequence->kv_state ().kv_cache_tokens_num ());
343+ }
344+
345+ void BlockManagerPool::cache (Sequence* sequence, size_t num_tokens) {
346+ DCHECK (sequence != nullptr );
347+ if (!options_.enable_prefix_cache ()) {
348+ return ;
349+ }
350+
351+ const size_t block_size = static_cast <size_t >(options_.block_size ());
352+ const size_t available_tokens_num =
353+ std::min ({num_tokens,
354+ sequence->kv_state ().num_kv_blocks () * block_size,
355+ sequence->tokens ().size ()});
356+ const size_t existed_shared_blocks_num =
357+ sequence->kv_state ().shared_kv_blocks_num ();
358+ if (available_tokens_num <= existed_shared_blocks_num * block_size) {
359+ return ;
360+ }
361+
343362 int32_t dp_rank = get_dp_rank (sequence);
344- const auto token_ids = sequence->cached_tokens ( );
363+ const auto token_ids = sequence->tokens (). slice ( 0 , available_tokens_num );
345364 auto * blocks = sequence->kv_state ().mutable_kv_blocks ();
346- auto existed_shared_blocks_num = sequence-> kv_state (). shared_kv_blocks_num ( );
365+ CHECK_GE (blocks-> size (), existed_shared_blocks_num );
347366 block_managers_[dp_rank]->cache (
348367 token_ids, *blocks, existed_shared_blocks_num);
349368}
0 commit comments