Skip to content

Commit 9e0d34d

Browse files
billmguofacebook-github-bot
authored andcommitted
clean up some runtime potential bugs
Summary: 1. Zero cache on allocation (line 65-66): std::fill on cache_data_ and update_data_ after allocator_.allocate() —eliminates uninitialized memory garbage that varies across devices. 2. Zero cache on reset (line 191): std::fill on cache_data_ in reset() — ensures stale KV cache from a previous prompt is fully cleared, not just the position counters. 3. Zero padding in last prefill chunk (line 618-621): When batch_len < input_len, fill the tail of the input buffer with zeros — prevents stale tokens from a previous chunk leaking through the embedding layer. sa_runner.cpp 4. Call runner.reset() before each prompt in the multi-prompt loop, stdin prompt loop, and stdin tokens loop —ensures the KV cache, masks, and input_pos_ are fully reset between prompts Differential Revision: D104615993
1 parent a49171d commit 9e0d34d

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

examples/models/llama/runner/static_attention_io_manager.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ class StaticKVCache {
6464
update_data_ = allocator_.allocate(update_data_size_);
6565
ET_CHECK(cache_data_ != nullptr);
6666
ET_CHECK(update_data_ != nullptr);
67+
std::fill(cache_data_, cache_data_ + cache_data_size_, T(0));
68+
std::fill(update_data_, update_data_ + update_data_size_, T(0));
6769
init_ptrs();
6870
}
6971

@@ -186,6 +188,7 @@ class StaticKVCache {
186188
*/
187189
void reset() {
188190
std::fill(cache_pos_.begin(), cache_pos_.end(), 0);
191+
std::fill(cache_data_, cache_data_ + cache_data_size_, T(0));
189192
}
190193

191194
private:
@@ -613,6 +616,10 @@ class StaticAttentionIOManager {
613616
return config_.generate_full_logits ? input_len - 1 : 0;
614617
}
615618
std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin());
619+
if (batch_len < input_len) {
620+
std::fill(
621+
input_buffer.begin() + batch_len, input_buffer.end(), TokenT(0));
622+
}
616623
if (!config_.generate_full_logits && config_.last_valid_token_pos_index) {
617624
last_valid_token_pos_ = batch_len - 1;
618625
set_input(

0 commit comments

Comments
 (0)