Skip to content

Commit 678f918

Browse files
committed
Refactor GetHybridCacheBytesForBlocks to use uint64_t (instead of unsigned __int128) for improved memory safety and MSVC compatibility, and add checks for overflow conditions. Update the related calculations in SequenceManager to match the new data type.
1 parent 710a8b2 commit 678f918

1 file changed

Lines changed: 33 additions & 14 deletions

File tree

src/turbomind/models/llama/SequenceManager.cc

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
// Copyright (c) OpenMMLab. All rights reserved.
22

33
#include <cstddef>
4+
#include <cstdint>
45
#include <cstdlib>
56
#include <ctime>
7+
#include <limits>
68
#include <numeric>
79

810
#include "src/turbomind/core/logger.h"
@@ -56,14 +58,25 @@ size_t GetHybridCheckpointSlotBytes(const std::vector<ssize_t>& conv_state_shape
5658
+ GetShapeBytes(recurrent_state_shape, recurrent_state_dtype);
5759
}
5860

59-
unsigned __int128 GetHybridCacheBytesForBlocks(size_t kv_block_count,
60-
size_t kv_block_bytes,
61-
size_t checkpoint_slot_bytes,
62-
int checkpoint_interval_blocks)
61+
uint64_t GetHybridCacheBytesForBlocks(size_t kv_block_count,
62+
size_t kv_block_bytes,
63+
size_t checkpoint_slot_bytes,
64+
int checkpoint_interval_blocks)
6365
{
6466
const size_t checkpoint_slots = checkpoint_interval_blocks > 0 ? kv_block_count / checkpoint_interval_blocks : 0;
65-
return static_cast<unsigned __int128>(kv_block_count) * kv_block_bytes
66-
+ static_cast<unsigned __int128>(checkpoint_slots) * checkpoint_slot_bytes;
67+
68+
TM_CHECK(kv_block_count == 0 || kv_block_bytes <= std::numeric_limits<uint64_t>::max() / kv_block_count)
69+
<< "KV cache byte count exceeds uint64_t range.";
70+
const uint64_t kv_bytes = static_cast<uint64_t>(kv_block_count) * static_cast<uint64_t>(kv_block_bytes);
71+
72+
TM_CHECK(checkpoint_slots == 0 || checkpoint_slot_bytes <= std::numeric_limits<uint64_t>::max() / checkpoint_slots)
73+
<< "Hybrid checkpoint byte count exceeds uint64_t range.";
74+
const uint64_t checkpoint_bytes =
75+
static_cast<uint64_t>(checkpoint_slots) * static_cast<uint64_t>(checkpoint_slot_bytes);
76+
77+
TM_CHECK(kv_bytes <= std::numeric_limits<uint64_t>::max() - checkpoint_bytes)
78+
<< "Total hybrid cache bytes exceed uint64_t range.";
79+
return kv_bytes + checkpoint_bytes;
6780
}
6881

6982
size_t SolveHybridKvBlockCount(size_t budget_bytes,
@@ -81,10 +94,10 @@ size_t SolveHybridKvBlockCount(size_t budget_bytes,
8194
size_t lo = 0;
8295
size_t hi = budget_bytes / kv_block_bytes;
8396
while (lo < hi) {
84-
const size_t mid = lo + (hi - lo + 1) / 2;
85-
const auto required_bytes =
97+
const size_t mid = lo + (hi - lo + 1) / 2;
98+
const uint64_t required_bytes =
8699
GetHybridCacheBytesForBlocks(mid, kv_block_bytes, checkpoint_slot_bytes, checkpoint_interval_blocks);
87-
if (required_bytes <= budget_bytes) {
100+
if (required_bytes <= static_cast<uint64_t>(budget_bytes)) {
88101
lo = mid;
89102
}
90103
else {
@@ -223,12 +236,18 @@ SequenceManager::SequenceManager(const ModelParam& model_param,
223236
}
224237
else if (num_linear_layers > 0 && block_count >= 1.) {
225238
const size_t requested_blocks = static_cast<size_t>(block_count);
226-
const auto requested_cache_bytes =
227-
has_linear_prefix_checkpoints ? hybrid_prefix_budget::GetHybridCacheBytesForBlocks(
228-
requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_) :
229-
static_cast<unsigned __int128>(requested_blocks) * block_size;
239+
uint64_t requested_cache_bytes{};
240+
if (has_linear_prefix_checkpoints) {
241+
requested_cache_bytes = hybrid_prefix_budget::GetHybridCacheBytesForBlocks(
242+
requested_blocks, block_size, linear_prefix_slot_bytes, linear_prefix_cache_interval_blocks_);
243+
}
244+
else {
245+
TM_CHECK(requested_blocks == 0 || block_size <= std::numeric_limits<uint64_t>::max() / requested_blocks)
246+
<< "Requested KV cache byte count exceeds uint64_t range.";
247+
requested_cache_bytes = static_cast<uint64_t>(requested_blocks) * static_cast<uint64_t>(block_size);
248+
}
230249
const size_t available_after_live = get_free_size();
231-
TM_CHECK(requested_cache_bytes <= static_cast<unsigned __int128>(available_after_live))
250+
TM_CHECK(requested_cache_bytes <= static_cast<uint64_t>(available_after_live))
232251
<< "Insufficient memory for "
233252
<< (has_linear_prefix_checkpoints ? "hybrid prefix cache blocks and checkpoints." : "KV cache blocks.");
234253
}

0 commit comments

Comments
 (0)