// Patch summary for examples/models/llama/runner/static_attention_io_manager.h.
// Every change below adds a std::fill call that writes zeros into a buffer;
// no existing statement is removed or modified.
//
// Hunk 1 (old line 64, inside class StaticKVCache; the enclosing function's
// opening is not part of the hunk):
//   After the two ET_CHECK null checks on cache_data_ and update_data_, fills
//   cache_data_size_ elements of cache_data_ and update_data_size_ elements of
//   update_data_ with T(0), before init_ptrs() is called.
//   NOTE(review): presumably the allocator does not hand back zeroed memory;
//   that is not visible here, confirm against the allocator in use.
//
// Hunk 2 (old line 186, StaticKVCache::reset):
//   reset() previously only set every entry of cache_pos_ to 0. It now also
//   fills cache_data_size_ elements of cache_data_ with T(0).
//   update_data_ is not touched by reset() in this patch.
//
// Hunk 3 (old line 613, inside class StaticAttentionIOManager; the enclosing
// function is not fully visible):
//   After copying batch_len tokens starting at tokens[i] into input_buffer,
//   and only when batch_len is less than input_len, fills the range from
//   input_buffer.begin() + batch_len to input_buffer.end() with TokenT(0).
//   NOTE(review): the guard compares against input_len while the fill runs to
//   input_buffer.end(); this assumes input_buffer holds input_len elements,
//   which the hunk does not show. Confirm against the caller.
//
// NOTE(review): the header and all three hunks appear on one physical line
// here. If that is how the file is stored, the line breaks have to be restored
// before a patch tool will accept it.
diff --git a/examples/models/llama/runner/static_attention_io_manager.h b/examples/models/llama/runner/static_attention_io_manager.h index 6f631df3ff0..83c80b48e1a 100644 --- a/examples/models/llama/runner/static_attention_io_manager.h +++ b/examples/models/llama/runner/static_attention_io_manager.h @@ -64,6 +64,8 @@ class StaticKVCache { update_data_ = allocator_.allocate(update_data_size_); ET_CHECK(cache_data_ != nullptr); ET_CHECK(update_data_ != nullptr); + std::fill(cache_data_, cache_data_ + cache_data_size_, T(0)); + std::fill(update_data_, update_data_ + update_data_size_, T(0)); init_ptrs(); } @@ -186,6 +188,7 @@ class StaticKVCache { */ void reset() { std::fill(cache_pos_.begin(), cache_pos_.end(), 0); + std::fill(cache_data_, cache_data_ + cache_data_size_, T(0)); } private: @@ -613,6 +616,10 @@ class StaticAttentionIOManager { return config_.generate_full_logits ? input_len - 1 : 0; } std::copy(&tokens[i], &tokens[i + batch_len], input_buffer.begin()); + if (batch_len < input_len) { + std::fill( + input_buffer.begin() + batch_len, input_buffer.end(), TokenT(0)); + } if (!config_.generate_full_logits && config_.last_valid_token_pos_index) { last_valid_token_pos_ = batch_len - 1; set_input(