11// Copyright (c) OpenMMLab. All rights reserved.
22
33#include < cstddef>
4+ #include < cstdint>
45#include < cstdlib>
56#include < ctime>
7+ #include < limits>
68#include < numeric>
79
810#include " src/turbomind/core/logger.h"
@@ -56,14 +58,25 @@ size_t GetHybridCheckpointSlotBytes(const std::vector<ssize_t>& conv_state_shape
5658 + GetShapeBytes (recurrent_state_shape, recurrent_state_dtype);
5759}
5860
59- unsigned __int128 GetHybridCacheBytesForBlocks (size_t kv_block_count,
60- size_t kv_block_bytes,
61- size_t checkpoint_slot_bytes,
62- int checkpoint_interval_blocks)
61+ uint64_t GetHybridCacheBytesForBlocks (size_t kv_block_count,
62+ size_t kv_block_bytes,
63+ size_t checkpoint_slot_bytes,
64+ int checkpoint_interval_blocks)
6365{
6466 const size_t checkpoint_slots = checkpoint_interval_blocks > 0 ? kv_block_count / checkpoint_interval_blocks : 0 ;
65- return static_cast <unsigned __int128>(kv_block_count) * kv_block_bytes
66- + static_cast <unsigned __int128>(checkpoint_slots) * checkpoint_slot_bytes;
67+
68+ TM_CHECK (kv_block_count == 0 || kv_block_bytes <= std::numeric_limits<uint64_t >::max () / kv_block_count)
69+ << " KV cache byte count exceeds uint64_t range." ;
70+ const uint64_t kv_bytes = static_cast <uint64_t >(kv_block_count) * static_cast <uint64_t >(kv_block_bytes);
71+
72+ TM_CHECK (checkpoint_slots == 0 || checkpoint_slot_bytes <= std::numeric_limits<uint64_t >::max () / checkpoint_slots)
73+ << " Hybrid checkpoint byte count exceeds uint64_t range." ;
74+ const uint64_t checkpoint_bytes =
75+ static_cast <uint64_t >(checkpoint_slots) * static_cast <uint64_t >(checkpoint_slot_bytes);
76+
77+ TM_CHECK (kv_bytes <= std::numeric_limits<uint64_t >::max () - checkpoint_bytes)
78+ << " Total hybrid cache bytes exceed uint64_t range." ;
79+ return kv_bytes + checkpoint_bytes;
6780}
6881
6982size_t SolveHybridKvBlockCount (size_t budget_bytes,
@@ -81,10 +94,10 @@ size_t SolveHybridKvBlockCount(size_t budget_bytes,
8194 size_t lo = 0 ;
8295 size_t hi = budget_bytes / kv_block_bytes;
8396 while (lo < hi) {
84- const size_t mid = lo + (hi - lo + 1 ) / 2 ;
85- const auto required_bytes =
97+ const size_t mid = lo + (hi - lo + 1 ) / 2 ;
98+ const uint64_t required_bytes =
8699 GetHybridCacheBytesForBlocks (mid, kv_block_bytes, checkpoint_slot_bytes, checkpoint_interval_blocks);
87- if (required_bytes <= budget_bytes) {
100+ if (required_bytes <= static_cast < uint64_t >( budget_bytes) ) {
88101 lo = mid;
89102 }
90103 else {
@@ -223,12 +236,18 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
223236 }
224237 else if (num_linear_layers > 0 && block_count >= 1 .) {
225238 const size_t requested_blocks = static_cast <size_t >(block_count);
226- const auto requested_cache_bytes =
227- has_linear_prefix_checkpoints ? hybrid_prefix_budget::GetHybridCacheBytesForBlocks (
228- requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_) :
229- static_cast <unsigned __int128>(requested_blocks) * block_size;
239+ uint64_t requested_cache_bytes{};
240+ if (has_linear_prefix_checkpoints) {
241+ requested_cache_bytes = hybrid_prefix_budget::GetHybridCacheBytesForBlocks (
242+ requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_);
243+ }
244+ else {
245+ TM_CHECK (requested_blocks == 0 || block_size <= std::numeric_limits<uint64_t >::max () / requested_blocks)
246+ << " Requested KV cache byte count exceeds uint64_t range." ;
247+ requested_cache_bytes = static_cast <uint64_t >(requested_blocks) * static_cast <uint64_t >(block_size);
248+ }
230249 const size_t available_after_live = get_free_size ();
231- TM_CHECK (requested_cache_bytes <= static_cast <unsigned __int128 >(available_after_live))
250+ TM_CHECK (requested_cache_bytes <= static_cast <uint64_t >(available_after_live))
232251 << " Insufficient memory for "
233252 << (has_linear_prefix_checkpoints ? " hybrid prefix cache blocks and checkpoints." : " KV cache blocks." );
234253 }
0 commit comments