diff --git a/WORKSPACE b/WORKSPACE index 134386e2a8..44ed3bc06e 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -174,8 +174,8 @@ boost_deps() http_archive( name = "net_zlib_zlib", build_file = "@com_resdb_nexres//third_party:z.BUILD", - sha256 = "91844808532e5ce316b3c010929493c0244f3d37593afd6de04f71821d5136d9", - strip_prefix = "zlib-1.2.12", + sha256 = "9a93b2b7dfdac77ceba5a558a580e74667dd6fede4585b91eefb60f03b72df23", + strip_prefix = "zlib-1.3.1", # sha256/strip_prefix must match the tarball named in urls urls = [ - "https://zlib.net/fossils/zlib-1.2.12.tar.gz", - "https://downloads.sourceforge.net/project/libpng/zlib/1.2.12/zlib-1.2.12.tar.gz", + "https://zlib.net/fossils/zlib-1.3.1.tar.gz", + "https://github.com/madler/zlib/releases/download/v1.3.1/zlib-1.3.1.tar.gz", @@ -225,10 +225,14 @@ bind( http_archive( name = "com_zlib", - build_file = "@com_resdb_nexres//third_party:zlib.BUILD", - sha256 = "629380c90a77b964d896ed37163f5c3a34f6e6d897311f1df2a7016355c45eff", - strip_prefix = "zlib-1.2.11", - url = "https://github.com/madler/zlib/archive/v1.2.11.tar.gz", + build_file = "@com_resdb_nexres//third_party:z.BUILD", + sha256 = "9a93b2b7dfdac77ceba5a558a580e74667dd6fede4585b91eefb60f03b72df23", + strip_prefix = "zlib-1.3.1", + urls = [ + "https://zlib.net/zlib-1.3.1.tar.gz", + "https://zlib.net/fossils/zlib-1.3.1.tar.gz", + "https://github.com/madler/zlib/releases/download/v1.3.1/zlib-1.3.1.tar.gz", + ], ) http_archive( diff --git a/benchmark/protocols/raft/BUILD b/benchmark/protocols/raft/BUILD new file mode 100644 index 0000000000..a65b722406 --- /dev/null +++ b/benchmark/protocols/raft/BUILD @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +package(default_visibility = ["//visibility:private"]) + +load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") + +cc_binary( + name = "kv_server_performance", + srcs = ["kv_server_performance.cpp"], + deps = [ + "//chain/storage:memory_db", + "//executor/kv:kv_executor", + "//platform/config:resdb_config_utils", + "//platform/consensus/ordering/raft/framework:consensus", + "//service/utils:server_factory", + ], +) + +cc_binary( + name = "kv_service_tools", + srcs = ["kv_service_tools.cpp"], + deps = [ + "//common/proto:signature_info_cc_proto", + "//interface/kv:kv_client", + "//platform/config:resdb_config_utils", + ], +) diff --git a/benchmark/protocols/raft/kv_server_performance.cpp b/benchmark/protocols/raft/kv_server_performance.cpp new file mode 100644 index 0000000000..a74ef45375 --- /dev/null +++ b/benchmark/protocols/raft/kv_server_performance.cpp @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include <glog/logging.h> + +#include "chain/storage/memory_db.h" +#include "executor/kv/kv_executor.h" +#include "platform/config/resdb_config_utils.h" +#include "platform/consensus/ordering/raft/framework/consensus.h" +#include "platform/networkstrate/service_network.h" +#include "platform/statistic/stats.h" +#include "proto/kv/kv.pb.h" + +using namespace resdb; +using namespace resdb::raft; +using namespace resdb::storage; + +void ShowUsage() { + printf("<config> <private_key> <cert_file> [logging_dir]\n"); +} + +std::string GetRandomKey() { + int num1 = rand() % 10; + int num2 = rand() % 10; + return std::to_string(num1) + std::to_string(num2); +} + +int main(int argc, char** argv) { + if (argc < 4) { // argv[1..3] (config, private key, cert) are all read below + ShowUsage(); + exit(0); + } + + // google::InitGoogleLogging(argv[0]); + // FLAGS_minloglevel = google::GLOG_WARNING; + + char* config_file = argv[1]; + char* private_key_file = argv[2]; + char* cert_file = argv[3]; + + if (argc >= 5) { + auto monitor_port = Stats::GetGlobalStats(5); + monitor_port->SetPrometheus(argv[4]); + } + + std::unique_ptr<ResDBConfig> config = + GenerateResDBConfig(config_file, private_key_file, cert_file); + + config->RunningPerformance(true); + ResConfigData config_data = config->GetConfigData(); + + auto performance_consens = std::make_unique<Consensus>( + *config, std::make_unique<KVExecutor>(std::make_unique<MemoryDB>())); + performance_consens->SetupPerformanceDataFunc([]() { + KVRequest request; + request.set_cmd(KVRequest::SET); + request.set_key(GetRandomKey()); + request.set_value("helloword"); + std::string request_data; + request.SerializeToString(&request_data); + return request_data; + }); + +
auto server = + std::make_unique<ServiceNetwork>(*config, std::move(performance_consens)); + server->Run(); +} diff --git a/benchmark/protocols/raft/kv_service_tools.cpp b/benchmark/protocols/raft/kv_service_tools.cpp new file mode 100644 index 0000000000..43627b34f4 --- /dev/null +++ b/benchmark/protocols/raft/kv_service_tools.cpp @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +#include +#include +#include +#include + +#include + +#include "common/proto/signature_info.pb.h" +#include "interface/kv/kv_client.h" +#include "platform/config/resdb_config_utils.h" + +using resdb::GenerateReplicaInfo; +using resdb::GenerateResDBConfig; +using resdb::KVClient; +using resdb::ReplicaInfo; +using resdb::ResDBConfig; + +int main(int argc, char** argv) { + if (argc < 2) { + printf("\n"); + return 0; + } + std::string client_config_file = argv[1]; + ResDBConfig config = GenerateResDBConfig(client_config_file); + + config.SetClientTimeoutMs(100000); + + KVClient client(config); + + client.Set("start", "value"); + printf("start benchmark\n"); +} diff --git a/chain/storage/leveldb.cpp b/chain/storage/leveldb.cpp index de49376e75..47ae7a84eb 100644 --- a/chain/storage/leveldb.cpp +++ b/chain/storage/leveldb.cpp @@ -228,8 +228,10 @@ bool ResLevelDB::UpdateMetrics() { return true; } -bool ResLevelDB::Flush() { - leveldb::Status status = db_->Write(leveldb::WriteOptions(), &batch_); +bool ResLevelDB::Flush(bool should_sync) { + leveldb::WriteOptions opts = leveldb::WriteOptions(); + opts.sync = should_sync; + leveldb::Status status = db_->Write(opts, &batch_); if (status.ok()) { batch_.Clear(); return true; diff --git a/chain/storage/leveldb.h b/chain/storage/leveldb.h index 67ec7a40c2..f55e062ecc 100644 --- a/chain/storage/leveldb.h +++ b/chain/storage/leveldb.h @@ -74,7 +74,7 @@ class ResLevelDB : public Storage { bool UpdateMetrics(); - bool Flush() override; + bool Flush(bool should_sync = false) override; virtual uint64_t GetLastCheckpoint() override; diff --git a/chain/storage/mock_storage.h b/chain/storage/mock_storage.h index 6b17c32620..14681db728 100644 --- a/chain/storage/mock_storage.h +++ b/chain/storage/mock_storage.h @@ -57,7 +57,7 @@ class MockStorage : public Storage { MOCK_METHOD(ItemsType, GetAllItems, (), (override)); MOCK_METHOD(ValuesSeqType, GetAllItemsWithSeq, (), (override)); - MOCK_METHOD(bool, Flush, (), (override)); + 
MOCK_METHOD(bool, Flush, (bool should_sync), (override)); }; } // namespace resdb diff --git a/chain/storage/storage.h b/chain/storage/storage.h index 7a0bfe5362..90db47694b 100644 --- a/chain/storage/storage.h +++ b/chain/storage/storage.h @@ -62,7 +62,7 @@ class Storage { // Default no-op SQL execution for non-SQL backends. virtual std::string ExecuteSQL(const std::string& sql_string) { return ""; } - virtual bool Flush() { return true; }; + virtual bool Flush(bool should_sync = false) { return true; }; virtual uint64_t GetLastCheckpoint() { return 0; } diff --git a/platform/consensus/ordering/common/algorithm/protocol_base.h b/platform/consensus/ordering/common/algorithm/protocol_base.h index f8e47052a2..d180746bda 100644 --- a/platform/consensus/ordering/common/algorithm/protocol_base.h +++ b/platform/consensus/ordering/common/algorithm/protocol_base.h @@ -63,9 +63,9 @@ class ProtocolBase { } protected: - int SendMessage(int msg_type, const google::protobuf::Message& msg, + virtual int SendMessage(int msg_type, const google::protobuf::Message& msg, int node_id); - int Broadcast(int msg_type, const google::protobuf::Message& msg); + virtual int Broadcast(int msg_type, const google::protobuf::Message& msg); int Commit(const google::protobuf::Message& msg); bool IsStop(); diff --git a/platform/consensus/ordering/common/framework/consensus.h b/platform/consensus/ordering/common/framework/consensus.h index 2f2884b893..022cc58bf3 100644 --- a/platform/consensus/ordering/common/framework/consensus.h +++ b/platform/consensus/ordering/common/framework/consensus.h @@ -53,7 +53,7 @@ class Consensus : public ConsensusManager { protected: int SendMsg(int type, const google::protobuf::Message& msg, int node_id); int Broadcast(int type, const google::protobuf::Message& msg); - int ResponseMsg(const BatchUserResponse& batch_resp); + virtual int ResponseMsg(const BatchUserResponse& batch_resp); void AsyncSend(); bool IsStop(); diff --git 
a/platform/consensus/ordering/common/framework/performance_manager.cpp b/platform/consensus/ordering/common/framework/performance_manager.cpp index ebaf1d6ab8..ed77c36460 100644 --- a/platform/consensus/ordering/common/framework/performance_manager.cpp +++ b/platform/consensus/ordering/common/framework/performance_manager.cpp @@ -52,8 +52,8 @@ PerformanceManager::PerformanceManager( total_num_ = 0; replica_num_ = config_.GetReplicaNum(); id_ = config_.GetSelfInfo().id(); - primary_ = id_ % replica_num_; - if (primary_ == 0) primary_ = replica_num_; + primary_.store(id_ % replica_num_); + if (primary_ == 0) primary_.store(replica_num_); local_id_ = 1; sum_ = 0; } @@ -67,7 +67,17 @@ PerformanceManager::~PerformanceManager() { } } -int PerformanceManager::GetPrimary() { return primary_; } +int PerformanceManager::GetPrimary() { return primary_.load(); } + +void PerformanceManager::SetPrimary(int id) { + int curr_primary = primary_.load(); + while (id != curr_primary) { + if (primary_.compare_exchange_strong(curr_primary, id)) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": primary updated to " << id; + return; + } + } +} int PerformanceManager::NeedResponse() { return config_.GetMinClientReceiveNum(); // f+1; @@ -88,16 +98,17 @@ int PerformanceManager::StartEval() { return 0; } eval_started_ = true; - for (int i = 0; i < 100000000; ++i) { - std::unique_ptr queue_item = std::make_unique(); - queue_item->context = nullptr; - queue_item->user_request = GenerateUserRequest(); - batch_queue_.Push(std::move(queue_item)); - if (i == 2000000) { - eval_ready_promise_.set_value(true); + std::thread([&](){ + for (int i = 0; i < 100000000; ++i) { + std::unique_ptr queue_item = std::make_unique(); + queue_item->context = nullptr; + queue_item->user_request = GenerateUserRequest(); + batch_queue_.Push(std::move(queue_item)); + if (i == 2000000) { + eval_ready_promise_.set_value(true); + } } - } - LOG(WARNING) << "start eval done"; +}).detach(); return 0; } @@ -176,8 +187,8 @@ void 
PerformanceManager::SendResponseToClient( if (create_time > 0) { uint64_t run_time = GetCurrentTime() - create_time; LOG(ERROR) << "receive current:" << GetCurrentTime() - << " create time:" << create_time << " run time:" << run_time - << " local id:" << batch_response.local_id(); + << " create time:" << create_time << " run time:" << run_time + << " local id:" << batch_response.local_id(); global_stats_->AddLatency(run_time); } send_num_--; @@ -186,8 +197,8 @@ void PerformanceManager::SendResponseToClient( // =================== request ======================== int PerformanceManager::BatchProposeMsg() { LOG(WARNING) << "batch wait time:" << config_.ClientBatchWaitTimeMS() - << " batch num:" << config_.ClientBatchNum() - << " max txn:" << config_.GetMaxProcessTxn(); + << " batch num:" << config_.ClientBatchNum() + << " max txn:" << config_.GetMaxProcessTxn(); std::vector> batch_req; eval_ready_future_.get(); bool start = false; diff --git a/platform/consensus/ordering/common/framework/performance_manager.h b/platform/consensus/ordering/common/framework/performance_manager.h index c2dce10269..b4fbf133a8 100644 --- a/platform/consensus/ordering/common/framework/performance_manager.h +++ b/platform/consensus/ordering/common/framework/performance_manager.h @@ -39,6 +39,7 @@ class PerformanceManager { virtual ~PerformanceManager(); int StartEval(); + void SetPrimary(int id); int ProcessResponseMsg(std::unique_ptr context, std::unique_ptr request); @@ -89,7 +90,7 @@ class PerformanceManager { std::mutex response_lock_[response_set_size_]; int replica_num_; int id_; - int primary_; + std::atomic primary_; std::atomic local_id_; std::atomic sum_; }; diff --git a/platform/consensus/ordering/pbft/consensus_manager_pbft.cpp b/platform/consensus/ordering/pbft/consensus_manager_pbft.cpp index d7b3880766..9023365f49 100644 --- a/platform/consensus/ordering/pbft/consensus_manager_pbft.cpp +++ b/platform/consensus/ordering/pbft/consensus_manager_pbft.cpp @@ -53,9 +53,9 @@ 
ConsensusManagerPBFT::ConsensusManagerPBFT( view_change_manager_(std::make_unique( config_, checkpoint_manager_.get(), message_manager_.get(), system_info_.get(), GetBroadCastClient(), GetSignatureVerifier())), - recovery_(std::make_unique(config_, checkpoint_manager_.get(), - system_info_.get(), - message_manager_->GetStorage())), + recovery_(std::make_unique( + config_, checkpoint_manager_.get(), system_info_.get(), + message_manager_->GetStorage())), query_(std::make_unique(config_, recovery_.get(), std::move(query_executor))) { LOG(INFO) << "is running is performance mode:" diff --git a/platform/consensus/ordering/pbft/consensus_manager_pbft.h b/platform/consensus/ordering/pbft/consensus_manager_pbft.h index 4df5e9c2e3..947ede07ac 100644 --- a/platform/consensus/ordering/pbft/consensus_manager_pbft.h +++ b/platform/consensus/ordering/pbft/consensus_manager_pbft.h @@ -28,7 +28,7 @@ #include "platform/consensus/ordering/pbft/query.h" #include "platform/consensus/ordering/pbft/response_manager.h" #include "platform/consensus/ordering/pbft/viewchange_manager.h" -#include "platform/consensus/recovery/recovery.h" +#include "platform/consensus/recovery/pbft_recovery.h" #include "platform/networkstrate/consensus_manager.h" namespace resdb { @@ -84,7 +84,7 @@ class ConsensusManagerPBFT : public ConsensusManager { std::unique_ptr response_manager_; std::unique_ptr performance_manager_; std::unique_ptr view_change_manager_; - std::unique_ptr recovery_; + std::unique_ptr recovery_; Stats* global_stats_; std::queue, std::unique_ptr>> request_pending_; diff --git a/platform/consensus/ordering/pbft/query.cpp b/platform/consensus/ordering/pbft/query.cpp index 197caac485..732fa437a4 100644 --- a/platform/consensus/ordering/pbft/query.cpp +++ b/platform/consensus/ordering/pbft/query.cpp @@ -24,7 +24,7 @@ namespace resdb { -Query::Query(const ResDBConfig& config, Recovery* recovery, +Query::Query(const ResDBConfig& config, PBFTRecovery* recovery, std::unique_ptr executor) : 
config_(config), recovery_(recovery), diff --git a/platform/consensus/ordering/pbft/query.h b/platform/consensus/ordering/pbft/query.h index 85f2e4c566..4678fb83ef 100644 --- a/platform/consensus/ordering/pbft/query.h +++ b/platform/consensus/ordering/pbft/query.h @@ -21,13 +21,13 @@ #include "executor/common/custom_query.h" #include "platform/config/resdb_config.h" -#include "platform/consensus/recovery/recovery.h" +#include "platform/consensus/recovery/pbft_recovery.h" namespace resdb { class Query { public: - Query(const ResDBConfig& config, Recovery* recovery, + Query(const ResDBConfig& config, PBFTRecovery* recovery, std::unique_ptr executor = nullptr); virtual ~Query(); @@ -41,7 +41,7 @@ class Query { protected: ResDBConfig config_; - Recovery* recovery_; + PBFTRecovery* recovery_; std::unique_ptr custom_query_executor_; }; diff --git a/platform/consensus/ordering/raft/algorithm/BUILD b/platform/consensus/ordering/raft/algorithm/BUILD new file mode 100644 index 0000000000..9b713d7845 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/BUILD @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +package(default_visibility = ["//platform/consensus/ordering/raft:__subpackages__"]) + +cc_library( + name = "raft", + srcs = [ + "raft.cpp", + "leaderelection_manager.cpp", + ], + hdrs = [ + "raft.h", + "leaderelection_manager.h", + ], + deps = [ + "//common:comm", + "//common/crypto:signature_verifier", + "//platform/common/queue:lock_free_queue", + "//platform/consensus/ordering/common/algorithm:protocol_base", + "//platform/consensus/ordering/raft/proto:proposal_cc_proto", + "//platform/statistic:stats", + "//platform/config:resdb_config", + "//platform/consensus/execution:system_info", + "//platform/networkstrate:replica_communicator", + "//platform/proto:viewchange_message_cc_proto", + "//platform/consensus/recovery:raft_recovery" + ], + visibility = ["//platform/consensus/ordering/raft:__subpackages__", + "//platform/consensus/recovery:__subpackages__"], +) + +cc_library( + name = "mock_raft", + hdrs = ["mock_raft.h"], + testonly = True, + deps = [ + ":raft", + ], +) + +cc_library( + name = "mock_leader_election_manager", + hdrs = ["mock_leader_election_manager.h"], + testonly = True, + deps = [ + ":raft", + ], +) + +cc_test( + name = "leader_election_test", + srcs = ["leader_election_manager_test.cpp"], + deps = [ + ":raft", + ":mock_raft", + "//platform/consensus/recovery:mock_raft_recovery", + "//platform/config:resdb_config_utils", + "//common/test:test_main" + ], + size="small" +) + +cc_test( + name = "raft_append_entries_test", + srcs = [ + "raft_append_entries_test.cpp", + "raft_tests.h", + "raft_test_util.h" + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft", + ":mock_leader_election_manager", + "//platform/consensus/recovery:mock_raft_recovery", + "//platform/networkstrate:mock_replica_communicator", + "//common/crypto:mock_signature_verifier", + "//platform/config:resdb_config_utils", + "//common/test:test_main", + "//platform/proto:client_test_cc_proto", + ], + size="small" +) + +cc_test( + name = "raft_append_entries_response_test", 
+ srcs = [ + "raft_append_entries_response_test.cpp", + "raft_tests.h", + "raft_test_util.h" + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft", + ":mock_leader_election_manager", + "//platform/consensus/recovery:mock_raft_recovery", + "//platform/networkstrate:mock_replica_communicator", + "//common/crypto:mock_signature_verifier", + "//platform/config:resdb_config_utils", + "//common/test:test_main", + "//platform/proto:client_test_cc_proto", + ], + size="small" +) + +cc_test( + name = "raft_request_vote_test", + srcs = [ + "raft_request_vote_test.cpp", + "raft_tests.h", + "raft_test_util.h" + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft", + ":mock_leader_election_manager", + "//platform/consensus/recovery:mock_raft_recovery", + "//platform/networkstrate:mock_replica_communicator", + "//common/crypto:mock_signature_verifier", + "//platform/config:resdb_config_utils", + "//common/test:test_main", + "//platform/proto:client_test_cc_proto", + ], + size="small" +) + +cc_test( + name = "raft_request_vote_response_test", + srcs = [ + "raft_request_vote_response_test.cpp", + "raft_tests.h", + "raft_test_util.h" + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft", + ":mock_leader_election_manager", + "//platform/consensus/recovery:mock_raft_recovery", + "//platform/networkstrate:mock_replica_communicator", + "//common/crypto:mock_signature_verifier", + "//platform/config:resdb_config_utils", + "//common/test:test_main", + "//platform/proto:client_test_cc_proto", + ], + size="small" +) + +cc_test( + name = "raft_integration_test", + srcs = [ + "raft_integration_test.cpp", + "raft_tests.h", + "raft_test_util.h" + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft", + ":mock_leader_election_manager", + "//platform/consensus/recovery:raft_recovery", + "//platform/networkstrate:mock_replica_communicator", + "//platform/consensus/checkpoint:mock_checkpoint", + "//common/crypto:mock_signature_verifier", + "//platform/config:resdb_config_utils", + 
"//common/test:test_main", + "//platform/proto:client_test_cc_proto", + ], + size="small" +) \ No newline at end of file diff --git a/platform/consensus/ordering/raft/algorithm/leader_election_manager_test.cpp b/platform/consensus/ordering/raft/algorithm/leader_election_manager_test.cpp new file mode 100644 index 0000000000..acb6cc5673 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/leader_election_manager_test.cpp @@ -0,0 +1,194 @@ +#include + +#include +#include +#include + +#include "platform/config/resdb_config_utils.h" +#include "platform/consensus/ordering/raft/algorithm/leaderelection_manager.h" +#include "platform/consensus/ordering/raft/algorithm/mock_raft.h" +#include "platform/consensus/recovery/mock_raft_recovery.h" + +namespace resdb { +namespace raft { + +using ::testing::Invoke; + +ResDBConfig GenerateConfig() { + ResConfigData data; + data.set_duplicate_check_frequency_useconds(100000); + data.set_enable_viewchange(true); + return ResDBConfig({GenerateReplicaInfo(1, "127.0.0.1", 1234), + GenerateReplicaInfo(2, "127.0.0.1", 1235), + GenerateReplicaInfo(3, "127.0.0.1", 1236), + GenerateReplicaInfo(4, "127.0.0.1", 1237)}, + GenerateReplicaInfo(1, "127.0.0.1", 1234), data); +} + +class TestLeaderElectionManager : public LeaderElectionManager { + public: + TestLeaderElectionManager(const ResDBConfig& config) + : LeaderElectionManager(config) {} + uint64_t GetHeartbeatCount() { + std::lock_guard lk(cv_mutex_); + return heartbeat_count_; + } + uint64_t GetBroadcastCount() { + std::lock_guard lk(cv_mutex_); + return broadcast_count_; + } + + private: + // Overriding this is used to set the timeout timer to start an election to 50 + // ms. 
+ uint64_t RandomInt(uint64_t min, uint64_t max) { return 50; } +}; + +class LeaderElectionManagerTest : public ::testing::Test { + protected: + LeaderElectionManagerTest() : config_(GenerateConfig()) {} + + void SetUp() override { + verifier_ = nullptr; + replica_communicator_ = nullptr; + leader_election_manager_ = + std::make_unique(config_); + mock_recovery_ = std::make_unique(config_); + mock_raft_ = std::make_unique(1, 1, 3, verifier_.get(), + leader_election_manager_.get(), + replica_communicator_.get(), + mock_recovery_.get()); + } + + void TearDown() override { + if (leader_election_manager_) { + leader_election_manager_.reset(); + } + if (mock_raft_) { + mock_raft_.reset(); + } + } + + ResDBConfig config_; + std::unique_ptr verifier_; + std::unique_ptr replica_communicator_; + std::unique_ptr leader_election_manager_; + std::unique_ptr mock_raft_; + std::unique_ptr mock_recovery_; +}; + +// Test 1: Follower timeout should trigger election. +TEST_F(LeaderElectionManagerTest, FollowerTimeoutTriggersElection) { + mock_raft_->SetRole(Role::FOLLOWER); + + std::promise election_started; + std::future election_started_future = election_started.get_future(); + + EXPECT_CALL(*mock_raft_, StartElection).WillOnce(Invoke([&]() { + election_started.set_value(true); + })); // Set before MayStart() so the timer thread cannot race the expectation. + + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + auto status = + election_started_future.wait_for(std::chrono::milliseconds(100)); + ASSERT_EQ(status, std::future_status::ready); +} + +// Test 2: Follower should not start election before timing out.
+TEST_F(LeaderElectionManagerTest, FollowerShouldNotStartElectionEarly) { + mock_raft_->SetRole(Role::FOLLOWER); + + std::promise election_started; + std::future election_started_future = election_started.get_future(); + + EXPECT_CALL(*mock_raft_, StartElection()).Times(0); + + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + std::this_thread::sleep_for(std::chrono::milliseconds(45)); + // Since the timeout timer is set to 50 ms, StartElection should never be + // called. +} + +// Test 3: Follower receiving heartbeat should NOT trigger election. +TEST_F(LeaderElectionManagerTest, + FollowerReceivingHeartbeatDoesNotStartElection) { + mock_raft_->SetRole(Role::FOLLOWER); + + std::promise election_started; + std::future election_started_future = election_started.get_future(); + + EXPECT_CALL(*mock_raft_, StartElection()).Times(0); + + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + std::this_thread::sleep_for(std::chrono::milliseconds(45)); + leader_election_manager_->OnHeartBeat(); + + std::this_thread::sleep_for(std::chrono::milliseconds(45)); + ASSERT_EQ(leader_election_manager_->GetHeartbeatCount(), 1); + // Since the timeout timer is set to 50 ms, StartElection should never be + // called. +} + +// Test 4: Leader timeout should send heartbeat. +TEST_F(LeaderElectionManagerTest, LeaderTimeoutSendsHeartbeat) { + mock_raft_->SetRole(Role::LEADER); + + std::promise heartbeat_sent; + std::future heartbeat_sent_future = heartbeat_sent.get_future(); + + EXPECT_CALL(*mock_raft_, SendHeartBeat).WillOnce(Invoke([&]() { + heartbeat_sent.set_value(true); + })); // Set before MayStart() so the timer thread cannot race the expectation. + + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + auto status = heartbeat_sent_future.wait_for(std::chrono::milliseconds(105)); + ASSERT_EQ(status, std::future_status::ready); +} + +// Test 5: Leader should not send heartbeat before timing out.
+TEST_F(LeaderElectionManagerTest, LeaderShouldNotSendHeartbeatEarly) { + mock_raft_->SetRole(Role::LEADER); + + std::promise heartbeat_sent; + std::future heartbeat_sent_future = heartbeat_sent.get_future(); + + EXPECT_CALL(*mock_raft_, SendHeartBeat()).Times(0); + + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + std::this_thread::sleep_for(std::chrono::milliseconds(95)); + // Since the heartbeat timer is set to 100 ms, SendHeartBeat should never be + // called. +} + +// Test 6: Leader sending some broadcast should not be sending heartbeats. +TEST_F(LeaderElectionManagerTest, LeaderWithBroadcastDoesNotSendHeartbeat) { + mock_raft_->SetRole(Role::LEADER); + + std::promise heartbeat_sent; + std::future heartbeat_sent_future = heartbeat_sent.get_future(); + + EXPECT_CALL(*mock_raft_, SendHeartBeat()).Times(0); + leader_election_manager_->SetRaft(mock_raft_.get()); + leader_election_manager_->MayStart(); + + // Send broadcasts to reset the timer. + for (int i = 0; i < 3; i++) { + std::this_thread::sleep_for(std::chrono::milliseconds(95)); + leader_election_manager_->OnAeBroadcast(); + } + + ASSERT_EQ(leader_election_manager_->GetBroadcastCount(), 3); +} + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/leaderelection_manager.cpp b/platform/consensus/ordering/raft/algorithm/leaderelection_manager.cpp new file mode 100644 index 0000000000..d2e4f1bf03 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/leaderelection_manager.cpp @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "platform/consensus/ordering/raft/algorithm/leaderelection_manager.h" +#include "platform/consensus/ordering/raft/algorithm/raft.h" +#include +#include +#include + +#include "common/utils/utils.h" +#include "platform/proto/viewchange_message.pb.h" + +namespace resdb { +namespace raft { + +// Drives Raft election timing: followers and candidates time out into an +// election, leaders emit periodic heartbeats. Shared counters are guarded +// by cv_mutex_ to keep the logic simple. +LeaderElectionManager::LeaderElectionManager(const ResDBConfig& config) + : config_(config), + raft_(nullptr), + started_(false), + stop_(false), + timeout_min_ms(1200), // lower bound (ms) of the randomized election timeout + timeout_max_ms(2400), // upper bound (ms) of the randomized election timeout + heartbeat_timer_(100), // leader wait between heartbeats, in ms + heartbeat_count_(0), + //last_heartbeat_time_(std::chrono::steady_clock::now()), + broadcast_count_(0), + role_epoch_(0), + known_role_epoch_(0) { + global_stats_ = Stats::GetGlobalStats(); + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": in LeaderElectionManager constructor"; +} + +LeaderElectionManager::~LeaderElectionManager() { + stop_.store(true); + cv_.notify_all(); + + if (server_checking_timeout_thread_.joinable()) { + server_checking_timeout_thread_.join(); + } +} + +void LeaderElectionManager::MayStart() { + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": in LeaderElectionManager MayStart"; + bool expected = false; + if (!started_.compare_exchange_strong(expected, true)) { + return; + } + + if (config_.GetPublicKeyCertificateInfo() + .public_key() + .public_key_info() + .type() == CertificateKeyInfo::CLIENT) { + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": in LeaderElectionManager
MayStart, Client conditional"; + LOG(ERROR) << "client type not process view change"; + return; + } + + if (config_.GetConfigData().enable_viewchange()) { + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Starting MonitoringElectionTimeout thread."; + server_checking_timeout_thread_ = + std::thread(&LeaderElectionManager::MonitoringElectionTimeout, this); + } +} + +void LeaderElectionManager::SetRaft(raft::Raft* raft) { + raft_ = raft; +} + +void LeaderElectionManager::OnHeartBeat() { + //auto now = std::chrono::steady_clock::now(); + //std::chrono::steady_clock::duration delta; + { + std::lock_guard lk(cv_mutex_); + heartbeat_count_++; + //delta = now - last_heartbeat_time_; + //last_heartbeat_time_ = now; + } + cv_.notify_all(); + //auto ms = std::chrono::duration_cast(delta).count(); + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Heartbeat received after " << ms << "ms"; +} + +void LeaderElectionManager::OnRoleChange() { + { + LOG(INFO) << "JIM -> " << __FUNCTION__; + std::lock_guard lk(cv_mutex_); + role_epoch_++; + } + cv_.notify_all(); +} + +void LeaderElectionManager::OnAeBroadcast() { + { + LOG(INFO) << "JIM -> " << __FUNCTION__; + std::lock_guard lk(cv_mutex_); + broadcast_count_++; + } + cv_.notify_all(); +} + +uint64_t LeaderElectionManager::RandomInt(uint64_t min, uint64_t max) { + static thread_local std::mt19937_64 gen(std::random_device{}()); + std::uniform_int_distribution dist(min, max); + return dist(gen); +} + +Waited LeaderElectionManager::LeaderWait() { + //LOG(INFO) << "JIM -> " << __FUNCTION__; + std::unique_lock lk(cv_mutex_); + const uint64_t broadcast_snapshot = broadcast_count_; + if (known_role_epoch_ != role_epoch_) { + known_role_epoch_ = role_epoch_; + return Waited::ROLE_CHANGE; + } + cv_.wait_for(lk, std::chrono::milliseconds(heartbeat_timer_), + [this, broadcast_snapshot] { + return (stop_.load() == true + || (known_role_epoch_ != role_epoch_) + || (broadcast_snapshot != broadcast_count_)); + }); + if (stop_.load() == true) { + 
return Waited::STOPPED; + } + else if (known_role_epoch_ != role_epoch_) { + known_role_epoch_ = role_epoch_; + return Waited::ROLE_CHANGE; + } + else if (broadcast_snapshot != broadcast_count_) { + return Waited::BROADCASTED; + } + else { + return Waited::TIMEOUT; + } +} + +Waited LeaderElectionManager::Wait() { + //LOG(INFO) << "JIM -> " << __FUNCTION__; + const uint64_t timeout_ms = RandomInt(timeout_min_ms, timeout_max_ms); + timeout_ms_ = timeout_ms; + std::unique_lock lk(cv_mutex_); + const uint64_t heartbeat_snapshot = heartbeat_count_; + if (known_role_epoch_ != role_epoch_) { + known_role_epoch_ = role_epoch_; + return Waited::ROLE_CHANGE; + } + cv_.wait_for(lk, std::chrono::milliseconds(timeout_ms), + [this, heartbeat_snapshot] { + return (stop_.load() == true + || (heartbeat_snapshot != heartbeat_count_) + || (known_role_epoch_ != role_epoch_)); + }); + if (stop_.load() == true) { + return Waited::STOPPED; + } + else if (known_role_epoch_ != role_epoch_) { + known_role_epoch_ = role_epoch_; + return Waited::ROLE_CHANGE; + } + else if (heartbeat_snapshot != heartbeat_count_) { + return Waited::HEARTBEAT; + } + else { + return Waited::TIMEOUT; + } +} + +// Function that is run in server_checking_timeout_thread started in MayStart(). +// Causes leaders to Heartbeat. +// Causes followers and candidates to start an election if no heartbeat received. 
+void LeaderElectionManager::MonitoringElectionTimeout() { + while (!stop_.load()) { + Role role = raft_->GetRoleSnapshot(); + Waited res; + std::chrono::steady_clock::time_point wait_start_time_ = std::chrono::steady_clock::now(); + bool leader = false; + if (role == Role::LEADER) { + res = LeaderWait(); + leader = true; + } + else { + res = Wait(); + } + std::chrono::steady_clock::time_point wait_end_time_ = std::chrono::steady_clock::now(); + std::chrono::steady_clock::duration delta = wait_end_time_ - wait_start_time_; + auto ms = std::chrono::duration_cast(delta).count(); + if (raft_->livenessLoggingFlag_) { + LOG(INFO) << __FUNCTION__ << ": " << (leader ? "Leader" : "") << "Wait " << ms << "ms"; + } + if (res == Waited::STOPPED) { + break; + } + else if (res == Waited::ROLE_CHANGE) { + LOG(INFO) << __FUNCTION__ << ": Role change detected"; + continue; + } + else if (res == Waited::HEARTBEAT) { + //LOG(INFO) << __FUNCTION__ << ": Heartbeat received within window"; + if (raft_->GetRoleSnapshot() == Role::LEADER) { + // A leader receiving a heartbeat would be unusual but not impossible. + LOG(WARNING) << __FUNCTION__ << " Received Heartbeat as LEADER"; + } + continue; + } + else if (res == Waited::BROADCASTED) { + if (raft_->livenessLoggingFlag_) { + LOG(INFO) << __FUNCTION__ << ": AE Broadcast reset leader heartbeat timer"; + } + continue; + } + + // Only gets here if timeout expired. + // Leaders send a new heartbeat. + if (raft_->GetRoleSnapshot() == Role::LEADER) { + raft_->SendHeartBeat(); + } + // Followers and Candidates start an election. 
+ else { + LOG(INFO) << __FUNCTION__ << ": Heartbeat timed out after " << timeout_ms_.load() << " ms"; + raft_->StartElection(); + } + } +} + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/leaderelection_manager.h b/platform/consensus/ordering/raft/algorithm/leaderelection_manager.h new file mode 100644 index 0000000000..987cadd367 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/leaderelection_manager.h @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include "platform/config/resdb_config.h" +#include "platform/consensus/execution/system_info.h" +#include "platform/proto/viewchange_message.pb.h" +#include "platform/statistic/stats.h" + +namespace resdb { +namespace raft { + +class Raft; // forward declaration + +enum class Waited { + HEARTBEAT, + STOPPED, + TIMEOUT, + ROLE_CHANGE, + BROADCASTED +}; + +class LeaderElectionManager { + public: + LeaderElectionManager(const ResDBConfig& config); + virtual ~LeaderElectionManager(); + + // If the monitor is not running, start to monitor. 
+ void MayStart(); + void SetRaft(raft::Raft*); + // This function is called upon receiving a heartbeat + virtual void OnHeartBeat(); + virtual void OnRoleChange(); + virtual void OnAeBroadcast(); + + private: + Waited LeaderWait(); + Waited Wait(); + void MonitoringElectionTimeout(); + virtual uint64_t RandomInt(uint64_t min, uint64_t max); + + protected: + ResDBConfig config_; + Stats* global_stats_; + raft::Raft* raft_; + std::map> viewchange_request_; + std::atomic started_; + std::atomic stop_; + std::thread server_checking_timeout_thread_; + std::atomic timeout_ms_; + uint64_t timeout_min_ms; + uint64_t timeout_max_ms; + uint64_t heartbeat_timer_; + uint64_t heartbeat_count_; // Protected by cv_mutex_ + uint64_t broadcast_count_; // Protected by cv_mutex_ + //std::chrono::steady_clock::time_point last_heartbeat_time_; + uint64_t role_epoch_; // Protected by cv_mutex_ + uint64_t known_role_epoch_; // Protected by cv_mutex_ + std::mutex cv_mutex_; + std::condition_variable cv_; +}; + +} // namespace raft +} // namespace resdb + diff --git a/platform/consensus/ordering/raft/algorithm/mock_leader_election_manager.h b/platform/consensus/ordering/raft/algorithm/mock_leader_election_manager.h new file mode 100644 index 0000000000..42b4e4502e --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/mock_leader_election_manager.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include "platform/consensus/ordering/raft/algorithm/leaderelection_manager.h" + +namespace resdb { +namespace raft { + +class MockLeaderElectionManager : public LeaderElectionManager { + public: + MockLeaderElectionManager(const ResDBConfig& config) + : LeaderElectionManager(config) {} + MOCK_METHOD(void, OnRoleChange, (), (override)); + MOCK_METHOD(void, OnHeartBeat, (), (override)); + MOCK_METHOD(void, OnAeBroadcast, (), (override)); +}; + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/mock_raft.h b/platform/consensus/ordering/raft/algorithm/mock_raft.h new file mode 100644 index 0000000000..2cd58089cf --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/mock_raft.h @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include "platform/consensus/ordering/raft/algorithm/raft.h" + +namespace resdb { +namespace raft { + +class MockRaft : public Raft { + public: + MockRaft(int id, int f, int total_num, SignatureVerifier* verifier, + LeaderElectionManager* leaderelection_manager, + ReplicaCommunicator* replica_communicator, RaftRecovery* recovery) + : Raft(id, f, total_num, verifier, leaderelection_manager, + replica_communicator, recovery) {} + + MOCK_METHOD(void, SendHeartBeat, (), ()); + MOCK_METHOD(void, StartElection, (), ()); + MOCK_METHOD(int, Broadcast, + (int msg_type, const google::protobuf::Message& msg), (override)); + MOCK_METHOD(int, SendMessage, + (int msg_type, const google::protobuf::Message& msg, int node_id), + (override)); +}; + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/raft.cpp b/platform/consensus/ordering/raft/algorithm/raft.cpp new file mode 100644 index 0000000000..4b78b13359 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/raft.cpp @@ -0,0 +1,1097 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "platform/consensus/ordering/raft/algorithm/raft.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "common/crypto/signature_verifier.h" +#include "common/utils/utils.h" +#include "platform/consensus/ordering/raft/proto/proposal.pb.h" +#include "platform/proto/resdb.pb.h" + +namespace resdb { +namespace raft { + +void PrintStackTrace() { + void* buffer[64]; + int n = backtrace(buffer, 64); + char** symbols = backtrace_symbols(buffer, n); + + for (int i = 0; i < n; ++i) { + LOG(INFO) << symbols[i]; + } + + free(symbols); +} + +std::ostream& operator<<(std::ostream& stream, Role role) { + const char* nameRole[] = {"FOLLOWER", "CANDIDATE", "LEADER"}; + return stream << nameRole[static_cast(role)]; +} + +std::ostream& operator<<(std::ostream& stream, TermRelation tr) { + const char* nameTR[] = {"STALE", "CURRENT", "NEW"}; + return stream << nameTR[static_cast(tr)]; +} + +uint32_t LogEntry::GetSerializedSize() const { + if (serializedSize == 0) { + serializedSize = ComputeSerializedEntrySize(); + } + return serializedSize; +} + +uint32_t LogEntry::ComputeSerializedEntrySize() const { + return entry.ByteSizeLong(); +} + +Raft::Raft(int id, int f, int total_num, SignatureVerifier* verifier, + LeaderElectionManager* leaderelection_manager, + ReplicaCommunicator* replica_communicator, RaftRecovery* recovery) + : ProtocolBase(id, f, total_num), + currentTerm_(0), + votedFor_(-1), + lastLogIndex_(-1), // This value is unsigned, but after the sentinel is + // added wraps back around to 0 + commitIndex_(0), + lastCommitted_(0), + role_(Role::FOLLOWER), + snapshot_last_index_(0), + snapshot_last_term_(0), + heartBeatsSentThisTerm_(0), + is_stop_(false), + quorum_((total_num / 2) + 1), + verifier_(verifier), + leader_election_manager_(leaderelection_manager), + replica_communicator_(replica_communicator), + 
recovery_(recovery) { + assert(recovery_); + id_ = id; + total_num_ = total_num; + f_ = (total_num-1)/2; + //last_ae_time_ = std::chrono::steady_clock::now(); + //last_heartbeat_time_ = std::chrono::steady_clock::now(); + + LogEntry sentinel; + sentinel.entry.set_term(0); + sentinel.entry.set_command("COMMON_PREFIX"); + AddToLog(sentinel, false); + lastLogIndex_ = 0; + + inflightVecs_.resize(total_num_ + 1); + for (auto& vec : inflightVecs_) { + vec.reserve(maxInFlightPerFollower); + } + nextIndex_.assign(total_num_ + 1, lastLogIndex_ + 1); + matchIndex_.assign(total_num_ + 1, lastLogIndex_); +} + +Raft::~Raft() { + is_stop_ = true; +} + +bool Raft::IsStop() { + return is_stop_; +} + +void Raft::SetRoleLocked(Role role) { role_ = role; } + +void Raft::SetRole(Role role) { + std::lock_guard lk(mutex_); + role_ = role; +} + +bool Raft::ReceiveTransaction(std::unique_ptr req) { + std::vector messages; + { + std::lock_guard lk(mutex_); + if (role_ != Role::LEADER) { + // Inform client proxy of new leader? + // Redirect transaction to a known leader? 
+ LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Replica is not leader, returning early"; + return false; + } + // append new transaction to log + LogEntry logEntry; + logEntry.entry.set_term(currentTerm_); + + std::string serialized; + if (!req->SerializeToString(&serialized)) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": req could not be serialized"; + return false; + } + + logEntry.entry.set_command(std::move(serialized)); + logEntry.GetSerializedSize(); + AddToLog(logEntry); + + nextIndex_[id_] = lastLogIndex_ + 1; + matchIndex_[id_] = lastLogIndex_; + + if (replicationLoggingFlag_) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Leader appended entry at index " << lastLogIndex_; + } + + // prepare fields for appendEntries message + PruneExpiredInFlightMsgsLocked(); + messages = GatherAeFieldsForBroadcastLocked(); + auto now = std::chrono::steady_clock::now(); + for (const auto& msg : messages) { + RecordNewInFlightMsgLocked(msg, now); + } + } + for (const auto& msg : messages) { + CreateAndSendAppendEntryMsg(msg); + } + leader_election_manager_->OnAeBroadcast(); + return true; +} + +bool Raft::ReceiveAppendEntries(std::unique_ptr ae) { + if (ae->leaderid() == id_) { + return false; + } + uint64_t term; + bool success = false; + bool demoted = false; + TermRelation tr; + Role initialRole; + uint64_t lastLogIndex; + auto leaderCommit = ae->leadercommitindex(); + auto leaderId = ae->leaderid(); + std::vector> eToApply; + + const char* parent_fn = __FUNCTION__; + [&]() { + std::lock_guard lk(mutex_); + // ---------- Checking term, role, prevlogindex, prevlogterm ---------- + initialRole = role_; + lastLogIndex = lastLogIndex_; + tr = TermCheckLocked(ae->term()); + if (tr == TermRelation::NEW) { + demoted = DemoteSelfLocked(ae->term()); + } + else if (role_ != Role::FOLLOWER && tr == TermRelation::CURRENT) { + demoted = DemoteSelfLocked(ae->term()); + } + + if (tr != TermRelation::STALE && role_ == Role::FOLLOWER) { + uint64_t i = ae->prevlogindex(); + + if (i <= 
snapshot_last_index_ || + (i < static_cast(GetLogicalLogSize()) && + ae->prevlogterm() == GetLogTermAtIndex(i))) { + success = true; + } + } + term = currentTerm_; + // Early return if we should not append + if (!success) { + return; + } + + // ---------- Appending entries ---------- + uint64_t logIdx = ae->prevlogindex() + 1; + uint64_t entriesIdx = 0; + // If we receive an entry that has already been snapshotted, that means it + // was committed, which means it must be identical to what we have. So, skip + // to the first entry after a snapshot. + if (logIdx <= snapshot_last_index_) { + entriesIdx = snapshot_last_index_ - logIdx + 1; + logIdx = snapshot_last_index_ + 1; + } + uint64_t entriesSize = static_cast(ae->entries_size()); + // check for conflicting entry terms in existing indices + // if conflict, delete suffix and short circuit out of loop + while (logIdx < GetLogicalLogSize() && entriesIdx < entriesSize) { + uint64_t term = ae->entries(entriesIdx).term(); + if (term != GetLogTermAtIndex(logIdx)) { + TruncateLog(logIdx); + + if (replicationLoggingFlag_) { + LOG(INFO) << "JIM -> " << parent_fn << ": follower saw term mismatch at index " << logIdx << ". 
Suffix erased from log"; + } + + break; + } + ++entriesIdx; + ++logIdx; + } + + // append remaining entries + const auto appendSize = entriesSize - entriesIdx; + std::vector log_entries_to_add; + for (uint64_t i = entriesIdx; i < entriesSize; ++i) { + log_entries_to_add.push_back(CreateLogEntry(ae->entries(i))); + } + + uint64_t firstAppendIdx = lastLogIndex_ + 1; + AddToLog(std::move(log_entries_to_add)); + lastLogIndex = lastLogIndex_; + + if (replicationLoggingFlag_ && appendSize > 0) { + if (appendSize > 1) { + LOG(INFO) << "JIM -> " << parent_fn << ": follower appended entries at indices " << firstAppendIdx << " to " << lastLogIndex_; + } + else { + LOG(INFO) << "JIM -> " << parent_fn << ": follower appended entry at index " << lastLogIndex_; + } + } + + // ---------- Try to raise commitIndex and commit entries ---------- + uint64_t prevCommitIndex = commitIndex_; + if (leaderCommit > commitIndex_) { + commitIndex_ = std::min(leaderCommit, lastLogIndex_); + + if (replicationLoggingFlag_ && commitIndex_ > prevCommitIndex) { + LOG(INFO) << "JIM -> " << parent_fn << ": Raised commitIndex_ from " + << prevCommitIndex << " to " << commitIndex_; + } + } + + // build vector to apply committed entries outside mutex + eToApply = PrepareCommitLocked(); + }(); + + /* + auto now = std::chrono::steady_clock::now(); + std::chrono::steady_clock::duration delta; + delta = now - last_ae_time_; + last_ae_time_ = now; + + + if (replicationLoggingFlag_) { + + auto ms = std::chrono::duration_cast(delta).count(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": AE received after " << ms << "ms"; + + } + */ + + // ---------- Outside mutex: inform leader_election_manager, apply committed entries, send response ---------- + if (demoted) { + leader_election_manager_->OnRoleChange(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Demoted from " + << (initialRole == Role::LEADER ? 
"LEADER" : "CANDIDATE") << "->FOLLOWER in term " << term; + } + + if (tr != TermRelation::STALE) { + leader_election_manager_->OnHeartBeat(); + } + + for (auto& entry : eToApply) { + commit_(*entry); + } + + AppendEntriesResponse aer; + aer.set_term(term); + aer.set_success(success); + aer.set_id(id_); + aer.set_lastlogindex(lastLogIndex); + SendMessage(MessageType::AppendEntriesResponseMsg, aer, leaderId); + + if (replicationLoggingFlag_) { + /* + if (success) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": responded success"; + } + else { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": responded failure"; + } + */ + } + return true; +} + +bool Raft::ReceiveAppendEntriesResponse(std::unique_ptr aer) { + uint64_t term; + bool demoted = false; + bool resending = false; + TermRelation tr; + Role initialRole; + std::vector> eToApply; + AeFields fields; + int followerId = aer->id(); + const char* parent_fn = __FUNCTION__; + [&]() { + std::lock_guard lk(mutex_); + initialRole = role_; + tr = TermCheckLocked(aer->term()); + if (tr == TermRelation::NEW) { + demoted = DemoteSelfLocked(aer->term()); + } + term = currentTerm_; + + if (role_ != Role::LEADER || tr == TermRelation::STALE) { + return; + } + PruneExpiredInFlightMsgsLocked(); + PruneRedundantInFlightMsgsLocked(followerId, aer->lastlogindex()); + nextIndex_[followerId] = aer->lastlogindex() + 1; + + // if successful, update matchIndex and try to commit more entries + if (aer->success()) { + // need to ensure matchIndex never decreases even if followers lastLogIndex decreases + matchIndex_[followerId] = std::max(matchIndex_[followerId], aer->lastlogindex()); + // use updated matchIndex to find new entries eligible for commit + std::vector sorted = matchIndex_; + std::sort(sorted.begin(), sorted.end(), std::greater()); + uint64_t lastReplicatedIndex = sorted[quorum_ - 1]; + // Need to check the lastReplicatedIndex contains entry from current term + if (lastReplicatedIndex > commitIndex_ && + 
GetLogTermAtIndex(lastReplicatedIndex) == currentTerm_) { + LOG(INFO) << "JIM -> " << parent_fn << ": Raised commitIndex_ from " + << commitIndex_ << " to " << lastReplicatedIndex; + commitIndex_ = lastReplicatedIndex; + } + // apply any newly committed entries to state machine + eToApply = PrepareCommitLocked(); + } + // if failure, or if nextIndex[i] < lastLogIndex + 1 (follower isnt caught up) + if (!aer->success() || (nextIndex_[followerId] < lastLogIndex_ + 1)) { + if (!aer->success()) { + LOG(INFO) << "AppendEntriesResponse indicates FAILURE from follower " << followerId; + LOG(INFO) << "NextIndex is: " << nextIndex_[followerId] << " their lastLogIndex is: " << aer->lastlogindex(); + } + if (aer->lastlogindex() < snapshot_last_index_) { + LOG(INFO) << "snapshot_last_index_ is: " << snapshot_last_index_; + SendInstallSnapshot(followerId); + } else if (!InFlightPerFollowerLimitReachedLocked(followerId)) { + fields = GatherAeFieldsLocked(followerId); + resending = true; + auto now = std::chrono::steady_clock::now(); + RecordNewInFlightMsgLocked(fields, now); + } + } + }(); + if (demoted) { + leader_election_manager_->OnRoleChange(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Demoted from " + << (initialRole == Role::LEADER ? 
"LEADER" : "CANDIDATE") << "->FOLLOWER in term " << term; + return false; + } + if (resending) { + CreateAndSendAppendEntryMsg(fields); + } + + for (auto& entry : eToApply) { + commit_(*entry); + } + return true; +} + +void Raft::ReceiveRequestVote(std::unique_ptr rv) { + int rvSender = rv->candidateid(); + uint64_t rvTerm = rv->term(); + + uint64_t term; + bool voteGranted = false; + bool demoted = false; + bool validCandidate = false; + int votedFor = -1; + Role initialRole; + + if (rvSender == id_) { + return; + } + + //const char* parent_fn = __FUNCTION__; + [&]() { + std::lock_guard lk(mutex_); + initialRole = role_; + // If their term is higher than ours, we accept new term, reset votedFor + // and convert to follower + TermRelation tr = TermCheckLocked(rvTerm); + if (tr == TermRelation::STALE) { + term = currentTerm_; + return; + } + else if (tr == TermRelation::NEW) { + demoted = DemoteSelfLocked(rvTerm); + } + // Then we continue voting process + term = currentTerm_; + votedFor = votedFor_; + + uint64_t lastLogTerm = getLastLogTermLocked(); + if (rv->lastlogterm() < lastLogTerm) { + return; + } + if (rv->lastlogterm() == lastLogTerm && rv->lastlogindex() < lastLogIndex_) { + return; + } + validCandidate = true; + if (votedFor_ == -1 || votedFor_ == rvSender) { + SetVotedFor(rvSender); + voteGranted = true; + } + }(); + if (demoted) { + leader_election_manager_->OnRoleChange(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Demoted from " + << (initialRole == Role::LEADER ? "LEADER" : "CANDIDATE") << "->FOLLOWER in term " << term; + } + if (voteGranted) { + leader_election_manager_->OnHeartBeat(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": voted for " << rvSender<< " in term " << term; + } + else if (validCandidate) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": did not vote for " + << rvSender<< " on term " << term << ". I already voted for " << votedFor + << ((votedFor == id_) ? 
" (myself)" : ""); + } + + RequestVoteResponse rvr; + rvr.set_term(term); + rvr.set_voterid(id_); + rvr.set_votegranted(voteGranted); + SendMessage(MessageType::RequestVoteResponseMsg, rvr, rvSender); +} + +void Raft::ReceiveRequestVoteResponse(std::unique_ptr rvr) { + uint64_t term = rvr->term(); + int voterId = rvr->voterid(); + bool votedYes = rvr->votegranted(); + bool demoted = false; + bool elected = false; + Role initialRole; + + const char* parent_fn = __FUNCTION__; + [&]() { + std::lock_guard lk(mutex_); + initialRole = role_; + TermRelation tr = TermCheckLocked(term); + if (tr == TermRelation::STALE) { + return; + } + else if (tr == TermRelation::NEW) { + demoted = DemoteSelfLocked(term); + return; + } + if (role_ != Role::CANDIDATE) { + return; + } + if (!votedYes) { + return; + } + bool dupe = (std::find(votes_.begin(), votes_.end(), voterId) != votes_.end()); + if (dupe) { + return; + } + votes_.push_back(voterId); + LOG(INFO) << "JIM -> " << parent_fn << ": Replica " << voterId << " voted for me. Votes: " + << votes_.size() << "/" << quorum_ << " in term " << currentTerm_; + if (votes_.size() >= quorum_) { + elected = true; + SetRoleLocked(Role::LEADER); + ClearInFlightsLocked(); + nextIndex_.assign(total_num_ + 1, lastLogIndex_ + 1); + + // make sure to set leaders own matchIndex entry to lastLogIndex + matchIndex_.assign(total_num_ + 1, 0); + matchIndex_[id_] = lastLogIndex_; + LOG(INFO) << "JIM -> " << parent_fn << ": CANDIDATE->LEADER in term " << currentTerm_; + } + }(); + if (demoted || elected) { + leader_election_manager_->OnRoleChange(); + } + if (demoted) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Demoted from " + << (initialRole == Role::LEADER ? 
"LEADER" : "CANDIDATE") << "->FOLLOWER in term " << term; + } + if (elected) { + SendHeartBeat(); + } +} + +Role Raft::GetRoleSnapshot() const { + std::lock_guard lk(mutex_); + return role_; +} + +// Called from LeaderElectionManager::StartElection when timeout +void Raft::StartElection() { + uint64_t currentTerm; + int candidateId; + uint64_t lastLogIndex; + uint64_t lastLogTerm; + bool roleChanged = false; + + { + std::lock_guard lk(mutex_); + if (role_ == Role::LEADER) { + LOG(WARNING) << __FUNCTION__ << ": Leader tried to start election"; + return; + } + if (role_ == Role::FOLLOWER) { + SetRoleLocked(Role::CANDIDATE); + roleChanged = true; + } + heartBeatsSentThisTerm_ = 0; + SetCurrentTermAndVotedFor(currentTerm_ + 1, id_); + votes_.clear(); + votes_.push_back(id_); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": I voted for myself. Votes: " + << votes_.size() << "/" << quorum_ << " in term " << currentTerm_; + + currentTerm = currentTerm_; + candidateId = id_; + lastLogIndex = lastLogIndex_; + lastLogTerm = getLastLogTermLocked(); + } + if (roleChanged) { + leader_election_manager_->OnRoleChange(); + LOG(INFO) << __FUNCTION__ << ": FOLLOWER->CANDIDATE in term " << currentTerm; + } + + RequestVote rv; + rv.set_term(currentTerm); + rv.set_candidateid(candidateId); + rv.set_lastlogindex(lastLogIndex); + rv.set_lastlogterm(lastLogTerm); + Broadcast(MessageType::RequestVoteMsg, rv); +} + +void Raft::SendHeartBeat() { + auto functionStart = std::chrono::steady_clock::now(); + std::chrono::steady_clock::duration functionDelta; + + std::vector messages; + uint64_t currentTerm; + uint64_t heartBeatNum; + { + std::lock_guard lk(mutex_); + if (role_ != Role::LEADER) { + LOG(WARNING) << __FUNCTION__ << ": Non-Leader tried to start HeartBeat"; + return; + } + currentTerm = currentTerm_; + + heartBeatsSentThisTerm_++; + heartBeatNum = heartBeatsSentThisTerm_; + bool heartbeat = true; + messages = GatherAeFieldsForBroadcastLocked(heartbeat); + } + + auto msgStart = 
std::chrono::steady_clock::now(); + std::chrono::steady_clock::duration msgDelta; + + for (const auto& msg : messages) { + CreateAndSendAppendEntryMsg(msg); + } + + auto msgEnd = std::chrono::steady_clock::now(); + msgDelta = msgEnd - msgStart; + auto msgMs = std::chrono::duration_cast(msgDelta).count(); + + if (livenessLoggingFlag_) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": " << msgMs << " ms elapsed in CreateAndSend loop"; + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Heartbeat " << heartBeatNum << " for term " << currentTerm; + } + + auto redirectStart = std::chrono::steady_clock::now(); + std::chrono::steady_clock::duration redirectDelta; + + // Also ping client proxies that this is the leader + DirectToLeader dtl; + dtl.set_term(currentTerm); + dtl.set_leaderid(id_); + for (const auto& client : replica_communicator_->GetClientReplicas()) { + int id = client.id(); + SendMessage(DirectToLeaderMsg, dtl, id); + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": DirectToLeader " << id_ << " sent to proxy " << id; + } + + auto redirectEnd = std::chrono::steady_clock::now(); + redirectDelta = redirectEnd - redirectStart; + auto redirectMs = std::chrono::duration_cast(redirectDelta).count(); + + + auto functionEnd = std::chrono::steady_clock::now(); + functionDelta = functionEnd - functionStart; + auto functionMs = std::chrono::duration_cast(functionDelta).count(); + + if (livenessLoggingFlag_) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": " << redirectMs << " ms elapsed in redirect loop"; + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": " << functionMs << " ms elapsed in function"; + } +} + +// requires raft mutex to be held +// returns true if demoted +bool Raft::DemoteSelfLocked(uint64_t term) { + if (term > currentTerm_) { + SetCurrentTermAndVotedFor(term, -1); + } + if (role_ != Role::FOLLOWER) { + SetRoleLocked(Role::FOLLOWER); + //LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Demoted to FOLLOWER"; + return true; + } + return false; +} + +// requires raft 
mutex to be held +TermRelation Raft::TermCheckLocked(uint64_t term) const { + if (term < currentTerm_) { + return TermRelation::STALE; + } + else if (term == currentTerm_) { + return TermRelation::CURRENT; + } + else { + return TermRelation::NEW; + } +} + +// requires raft mutex to be held +uint64_t Raft::getLastLogTermLocked() const { + if (lastLogIndex_ <= snapshot_last_index_) { + return snapshot_last_term_; + } + + return GetLogTermAtIndex(lastLogIndex_); +} + +// requires raft mutex to be held +std::vector> Raft::PrepareCommitLocked() { + std::vector> commitVec; + uint64_t begin = lastCommitted_ + 1; + bool applying = false; + while (lastCommitted_ < commitIndex_ && + lastCommitted_ < GetLogicalLogSize() - 1) { + ++lastCommitted_; + auto command = std::make_unique(); + + if (!command->ParseFromString( + GetLogEntryAtIndex(lastCommitted_).entry.command())) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Failed to parse command"; + continue; + } + // assign seq number as log index for the request or executing transactions fails. 
+ command->set_seq(lastCommitted_); + commitVec.push_back(std::move(command)); + applying = true; + } + + if (applying && replicationLoggingFlag_) { + if (lastCommitted_ > begin) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Applying index entries " + << begin << " to " << lastCommitted_; + } else { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Applying index entry " + << lastCommitted_; + } + } + + return commitVec; +} + +AeFields Raft::GatherAeFieldsLocked(int followerId, bool heartBeat) const { + AeFields fields{}; + LOG(INFO) << "snapshot_last_index_ is: " << snapshot_last_index_; + assert((nextIndex_[followerId] - 1 >= snapshot_last_index_) || heartBeat); + + fields.term = currentTerm_; + fields.leaderId = id_; + fields.leaderCommit = commitIndex_; + fields.prevLogIndex = nextIndex_[followerId] - 1; + fields.prevLogTerm = GetLogTermAtIndex(fields.prevLogIndex); + fields.followerId = followerId; + if (heartBeat) { + return fields; + } + uint32_t msgBytes = maxHeaderBytes; + const uint64_t firstNew = nextIndex_[followerId]; + const uint64_t limit = std::min(lastLogIndex_, (firstNew + maxEntries) - 1); + for (uint64_t i = firstNew; i <= limit; ++i) { + msgBytes += GetLogEntryAtIndex(i).GetSerializedSize(); + // Always include at least 1 entry, after that limit by maxBytes. 
+ if (i != firstNew && msgBytes >= maxBytes) { + break; + } + LogEntry entry; + entry.entry = GetLogEntryAtIndex(i).entry; + fields.entries.push_back(entry); + } + return fields; +} + +// returns vector of tuples +// If heartBeat == true, entries[] will be empty for all messages +// else entries will each contain at most maxEntries amount of entries +// Followers will be excluded from the broadcast if they are at inflight max unless this is a heartbeat +std::vector Raft::GatherAeFieldsForBroadcastLocked(bool heartBeat) const { + assert(role_ == Role::LEADER); + std::vector fieldsVec; + fieldsVec.reserve(total_num_ - 1); + for (size_t i = 1; i <= total_num_; ++i) { + if (i == id_) { + continue; + } + if (!heartBeat && InFlightPerFollowerLimitReachedLocked(i)) { + continue; + } + if (nextIndex_[i] - 1 >= snapshot_last_index_) { + AeFields fields = GatherAeFieldsLocked(i, heartBeat); + fieldsVec.push_back(fields); + } + } + return fieldsVec; +} + +void Raft::CreateAndSendAppendEntryMsg(const AeFields& fields) { + int followerId = fields.followerId; + AppendEntries ae; + ae.set_term(fields.term); + ae.set_leaderid(fields.leaderId); + ae.set_prevlogindex(fields.prevLogIndex); + ae.set_prevlogterm(fields.prevLogTerm); + ae.set_leadercommitindex(fields.leaderCommit); + for (const auto& entry : fields.entries) { + Entry* newEntry = ae.add_entries(); + newEntry->set_term(entry.entry.term()); + newEntry->set_command(entry.entry.command()); + } + SendMessage(MessageType::AppendEntriesMsg, ae, followerId); + if (replicationLoggingFlag_) { + uint64_t entryCount = fields.entries.size(); + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Sent AE with " << entryCount << (entryCount == 1 ? 
" entry" : " entries"); + } +} + +LogEntry Raft::CreateLogEntry(const Entry& entry) const { + LogEntry newEntry; + newEntry.entry = entry; + return newEntry; +} + +void Raft::ClearInFlightsLocked() { + assert(role_ == Role::LEADER); + for (auto& vec : inflightVecs_) { + vec.clear(); + } +} + +void Raft::PruneExpiredInFlightMsgsLocked() { + assert(role_ == Role::LEADER); + auto now = std::chrono::steady_clock::now(); + for (size_t i = 1; i < inflightVecs_.size(); ++i) { + if (i == id_) { + continue; + } + auto& vec = inflightVecs_[i]; + if (vec.empty()) { + continue; + } + auto it = vec.begin(); + while(it != vec.end()) { + auto timeElapsed = now - it->timeSent; + if (timeElapsed >= AEResponseDeadline) { + it = vec.erase(it); + if (replicationLoggingFlag_) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Pruned expired inflight AE for follower " << i; + } + } + else { + ++it; + } + } + } +} + +void Raft::PruneRedundantInFlightMsgsLocked(int followerId, uint64_t followerLastLogIndex) { + assert(role_ == Role::LEADER); + assert(followerId > 0); + assert(static_cast(followerId) < inflightVecs_.size()); + assert(followerId != id_); + + auto& msgVec = inflightVecs_[followerId]; + if (msgVec.empty()) { + return; + } + auto it = msgVec.begin(); + while(it != msgVec.end()) { + if (it->prevLogIndexSent > followerLastLogIndex || it->lastIndexOfSegmentSent <= followerLastLogIndex) { + it = msgVec.erase(it); + if (replicationLoggingFlag_) { + LOG(INFO) << "JIM -> " << __FUNCTION__ << ": Pruned redundant inflight AE for follower " << followerId; + } + } + else { + ++it; + } + } +} + +void Raft::RecordNewInFlightMsgLocked(const AeFields& msg, std::chrono::steady_clock::time_point timestamp) { + if (msg.entries.empty()) { + return; + } + InFlightMsg inFlight; + inFlight.timeSent = timestamp; + inFlight.prevLogIndexSent = msg.prevLogIndex; + inFlight.lastIndexOfSegmentSent = msg.prevLogIndex + msg.entries.size(); + inflightVecs_[msg.followerId].push_back(inFlight); +} + +bool 
Raft::InFlightPerFollowerLimitReachedLocked(int followerId) const { + assert(role_ == Role::LEADER); + assert(followerId > 0); + assert(static_cast(followerId) < inflightVecs_.size()); + assert(followerId != id_); + + auto size = inflightVecs_[followerId].size(); + assert(size <= maxInFlightPerFollower); + return size == maxInFlightPerFollower; +} + +const LogEntry& Raft::GetLogEntryAtIndex(uint64_t index) const { + assert(index > snapshot_last_index_ && + "Tried to access entry that has been snapshotted"); + // A sentinel value is always included after a snapshot + // Example: snapshot_last_index_ = 5, we have truncated the entire log, added + // 1 entry, then log.size() == 2 with the sentinel. index could be 6, and + // snapshot_last_index_ + log.size() == 7 + assert(index < snapshot_last_index_ + log_.size() && + "Tried to access element that has not been added yet"); + return log_[index - snapshot_last_index_]; +} + +const uint64_t Raft::GetLogTermAtIndex(uint64_t index) const { + assert(index >= snapshot_last_index_ && + "Tried to access entry that has been snapshotted"); + // A sentinel value is always included after a snapshot + // Example: snapshot_last_index_ = 5, we have truncated the entire log, added + // 1 entry, then log.size() == 2 with the sentinel. index could be 6, and + // snapshot_last_index_ + log.size() == 7 + assert(index < snapshot_last_index_ + log_.size() && + "Tried to access element that has not been added yet"); + if (index == snapshot_last_index_) { + return snapshot_last_term_; + } + + return log_[index - snapshot_last_index_].entry.term(); +} + +// This would be what log.size() returns if no prefix truncation occurred. 
+int Raft::GetLogicalLogSize() const { + return log_.size() + snapshot_last_index_; +} + +void Raft::SetCurrentTerm(uint64_t currentTerm, bool writeMetadata) { + currentTerm_ = currentTerm; + if (writeMetadata) { + WriteMetadata(); + } +} + +void Raft::SetVotedFor(int votedFor, bool writeMetadata) { + votedFor_ = votedFor; + if (writeMetadata) { + WriteMetadata(); + } +} + +void Raft::SetCurrentTermAndVotedFor(uint64_t currentTerm, int votedFor, + bool writeMetadata) { + currentTerm_ = currentTerm; + votedFor_ = votedFor; + if (writeMetadata) { + WriteMetadata(); + } +} + +void Raft::SetSnapshotLastIndexAndTerm(uint64_t snapshot_last_index, + uint64_t snapshot_last_term, + bool writeMetadata) { + uint64_t old_snapshot_last_index = snapshot_last_index_; + snapshot_last_index_ = snapshot_last_index; + snapshot_last_term_ = snapshot_last_term; + LOG(INFO) << "setting snapshot_last_index " << snapshot_last_index + << " and snapshot_last_term" << snapshot_last_term; + if (writeMetadata) { + WriteMetadata(); + return; + } + if (old_snapshot_last_index) { + LOG(INFO) << "snapshot_last_index already set during recovery"; + return; + } + + lastLogIndex_ = snapshot_last_index_; + commitIndex_ = snapshot_last_index_; + lastCommitted_ = snapshot_last_index_; + log_[0].entry.set_term(snapshot_last_term_); +} + +uint64_t Raft::GetSnapshotLastIndex() { return snapshot_last_index_; } + +void Raft::WriteMetadata() { + recovery_->WriteMetadata(currentTerm_, votedFor_, snapshot_last_index_, + snapshot_last_term_); +} + +void Raft::AddToLog(LogEntry &logEntryToAdd, bool writeMetadata) { + lastLogIndex_++; + Entry* entry; + entry = &logEntryToAdd.entry; + if (writeMetadata) { + recovery_->AddLogEntry(entry, lastLogIndex_); + } + log_.push_back(logEntryToAdd); + assert(lastLogIndex_ == GetLogicalLogSize() - 1); +} + +void Raft::AddToLog(std::vector logEntriesToAdd, bool writeMetadata) { + if (writeMetadata) { + std::vector entries_to_add; + for (const auto& entry : logEntriesToAdd) { + 
entries_to_add.push_back(entry.entry); + } + recovery_->AddLogEntry(entries_to_add, lastLogIndex_ + 1); + } + + lastLogIndex_ += logEntriesToAdd.size(); + log_.reserve(log_.size() + logEntriesToAdd.size()); + log_.insert(log_.end(), std::make_move_iterator(logEntriesToAdd.begin()), + std::make_move_iterator(logEntriesToAdd.end())); + + assert(lastLogIndex_ == GetLogicalLogSize() - 1); +} + +void Raft::TruncateLog(uint64_t firstIndex, bool writeMetadata) { + assert(firstIndex > commitIndex_); + auto first = log_.begin() + (firstIndex - snapshot_last_index_); + auto last = log_.begin() + (lastLogIndex_ - snapshot_last_index_) + 1; + auto num_elements_erased = lastLogIndex_ - firstIndex + 1; + if (writeMetadata) { + TruncationRecord truncation; + truncation.set_truncate_from_index(firstIndex); + truncation.set_truncate_from_term(GetLogTermAtIndex(firstIndex)); + recovery_->TruncateLog(truncation); + } + + log_.erase(first, last); + lastLogIndex_ -= num_elements_erased; + assert(lastLogIndex_ == GetLogicalLogSize() - 1); +} + +void Raft::TruncatePrefix(uint64_t index) { + std::lock_guard lk(mutex_); + TruncatePrefixLocked(index); +} + +void Raft::TruncatePrefixLocked(uint64_t index) { + assert(index > snapshot_last_index_ && + "Tried to truncate an entry that has been snapshotted"); + assert(index <= lastCommitted_ && + "Tried to prefix truncate an element that has not been committed"); + LOG(INFO) << "Setting Snapshot last index to:" << index + 1; + + // Keep the sentinel, erase everything up to the index. 
+ auto erase_end = log_.begin() + (index - snapshot_last_index_); + auto last_snapshotted_entry_term = GetLogTermAtIndex(index); + log_.erase(log_.begin() + 1, erase_end + 1); + assert(log_[0].entry.term() == last_snapshotted_entry_term); + SetSnapshotLastIndexAndTerm(index, last_snapshotted_entry_term); + + assert(lastLogIndex_ == GetLogicalLogSize() - 1); +} + +void Raft::SendInstallSnapshot(int followerId) {} + +/* +void Raft::ReceiveInstallSnapshot() { + +} + +void Raft::ReceiveInstallSnapshotResponse() { + +} +*/ + +void Raft::PrintDebugStateLocked() const { + std::lock_guard lk(mutex_); + PrintDebugState(); +} + +void Raft::PrintDebugState() const { + std::ostringstream oss; + + oss << "---- Raft Debug State ----\n"; + + oss << "currentTerm_: " << currentTerm_ << "\n"; + oss << "votedFor_: " << votedFor_ << "\n"; + + // log_ + oss << "log_ (size " << GetLogicalLogSize() << "): ["; + for (size_t i = 0; i < GetLogicalLogSize(); ++i) { + oss << "{term: " << GetLogTermAtIndex(i) << "}"; + if (i + 1 != GetLogicalLogSize()) oss << ", "; + } + oss << "]\n"; + + // nextIndex_ + oss << "nextIndex_: ["; + for (size_t i = 0; i < nextIndex_.size(); ++i) { + oss << nextIndex_[i]; + if (i + 1 != nextIndex_.size()) oss << ", "; + } + oss << "]\n"; + + // matchIndex_ + oss << "matchIndex_: ["; + for (size_t i = 0; i < matchIndex_.size(); ++i) { + oss << matchIndex_[i]; + if (i + 1 != matchIndex_.size()) oss << ", "; + } + oss << "]\n"; + + oss << "heartBeatsSentThisTerm_: " << heartBeatsSentThisTerm_ << "\n"; + oss << "lastLogIndex_: " << lastLogIndex_ << "\n"; + oss << "commitIndex_: " << commitIndex_ << "\n"; + oss << "lastCommitted_: " << lastCommitted_ << "\n"; + oss << "role_: " << static_cast(role_) << "\n"; + + // votes_ + oss << "votes_: ["; + for (size_t i = 0; i < votes_.size(); ++i) { + oss << votes_[i]; + if (i + 1 != votes_.size()) oss << ", "; + } + oss << "]\n"; + + oss << "--------------------------"; + + LOG(INFO) << oss.str(); +} + +} // namespace raft +} 
// namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/raft.h b/platform/consensus/ordering/raft/algorithm/raft.h new file mode 100644 index 0000000000..5d33142c32 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/raft.h @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef RAFT_TEST_MODE +#include +#endif + +#include "platform/common/queue/lock_free_queue.h" +#include "platform/consensus/ordering/common/algorithm/protocol_base.h" +#include "platform/consensus/ordering/raft/algorithm/leaderelection_manager.h" +#include "platform/consensus/ordering/raft/proto/proposal.pb.h" +#include "platform/consensus/recovery/raft_recovery.h" +#include "platform/networkstrate/replica_communicator.h" +#include "platform/proto/resdb.pb.h" +#include "platform/statistic/stats.h" + +namespace resdb { +namespace raft { + +enum class Role { FOLLOWER, CANDIDATE, LEADER }; +enum class TermRelation { STALE, CURRENT, NEW }; + +class LogEntry { + public: + Entry entry; + + uint32_t GetSerializedSize() const; + uint32_t ComputeSerializedEntrySize() const; + + private: + mutable uint32_t serializedSize = 0; +}; + +struct AeFields { + uint64_t term = 0; + int leaderId = -1; + uint64_t prevLogIndex = 0; + uint64_t prevLogTerm = 0; + std::vector entries{}; + uint64_t leaderCommit = 0; + int followerId = -1; // not part of AE message itself, but needed to determine recipient +}; + +struct InFlightMsg { + std::chrono::steady_clock::time_point timeSent; + uint64_t prevLogIndexSent; + uint64_t lastIndexOfSegmentSent; +}; + +#ifdef RAFT_TEST_MODE +struct RaftStatePatch { + std::optional currentTerm; + std::optional votedFor; + std::optional commitIndex; + std::optional lastCommitted; + std::optional role; + + std::optional> log; + std::optional> nextIndex; + std::optional> matchIndex; + std::optional> votes; +}; +#endif + +class Raft : public common::ProtocolBase { + public: + Raft(int id, int f, int total_num, + SignatureVerifier* verifier, + LeaderElectionManager* leaderelection_manager, + ReplicaCommunicator* replica_communicator, + RaftRecovery* recovery + ); + ~Raft(); + + const bool replicationLoggingFlag_ = true; + const bool livenessLoggingFlag_ = false; 
+ + virtual bool ReceiveTransaction(std::unique_ptr req); + virtual bool ReceiveAppendEntries(std::unique_ptr ae); + virtual bool ReceiveAppendEntriesResponse( + std::unique_ptr aer); + virtual void ReceiveRequestVote(std::unique_ptr rv); + virtual void ReceiveRequestVoteResponse( + std::unique_ptr rvr); + virtual void StartElection(); + virtual void SendHeartBeat(); + virtual Role GetRoleSnapshot() const; + virtual void SetRole(Role role); + virtual void PrintDebugStateLocked() const; + virtual void PrintDebugState() const; + void WriteMetadata(); + uint64_t GetSnapshotLastIndex(); + + // These functions with writeMetadata are also used to replay information upon + // recovery. So, they are called with false during recovery, and true + // everywhere else. + virtual void SetCurrentTerm(uint64_t currentTerm, bool writeMetadata = true); + virtual void SetVotedFor(int votedFor, bool writeMetadata = true); + virtual void SetCurrentTermAndVotedFor(uint64_t currentTerm, int votedFor, + bool writeMetadata = true); + void SetSnapshotLastIndexAndTerm(uint64_t snapshot_last_index, + uint64_t snapshot_last_term, + bool writeMetadata = true); + void AddToLog(LogEntry &logEntry, bool writeMetadata = true); + void AddToLog(std::vector logEntriesToAdd, + bool writeMetadata = true); + void TruncateLog(uint64_t first, bool writeMetadata = true); + void TruncatePrefix(uint64_t index); + + private: + mutable std::mutex mutex_; + + virtual TermRelation TermCheckLocked( + uint64_t term) const; // Must be called under mutex + virtual bool DemoteSelfLocked(uint64_t term); // Must be called under mutex + virtual uint64_t getLastLogTermLocked() const; // Must be called under mutex + virtual bool IsStop(); + //bool IsDuplicateLogEntry(const std::string& hash) const; // Must be called under mutex + virtual std::vector> + PrepareCommitLocked(); // Must be called under mutex + virtual AeFields GatherAeFieldsLocked(int followerId, bool heartBeat = false) + const; // Must be called under mutex + 
std::vector GatherAeFieldsForBroadcastLocked(bool heartBeat = false) const; // Must be called under mutex + virtual void CreateAndSendAppendEntryMsg(const AeFields& fields); + virtual LogEntry CreateLogEntry(const Entry& entry) const; + virtual void ClearInFlightsLocked(); + virtual void PruneExpiredInFlightMsgsLocked(); + virtual void PruneRedundantInFlightMsgsLocked(int followerId, + uint64_t followerLastLogIndex); + virtual void RecordNewInFlightMsgLocked( + const AeFields& msg, std::chrono::steady_clock::time_point timestamp); + virtual bool InFlightPerFollowerLimitReachedLocked(int followerId) const; + int GetLogicalLogSize() const; +#ifdef RAFT_TEST_MODE + public: +#endif + const LogEntry& GetLogEntryAtIndex(uint64_t index) const; + const uint64_t GetLogTermAtIndex(uint64_t index) const; +#ifdef RAFT_TEST_MODE + private: +#endif + void SendInstallSnapshot(int followerId); + void TruncatePrefixLocked(uint64_t index); + void SetRoleLocked(Role role); + + // Persistent state on all servers: + uint64_t currentTerm_; // Protected by mutex_ + int votedFor_; // Protected by mutex_ + std::vector log_; // Protected by mutex_ + + // Volatile state on leaders: + std::vector nextIndex_; // Protected by mutex_ + std::vector matchIndex_; // Protected by mutex_ + uint64_t heartBeatsSentThisTerm_; // Protected by mutex_ + uint64_t lastLogIndex_; // Protected by mutex_ + + // Volatile state on all servers: + uint64_t commitIndex_; // Protected by mutex_ + // lastCommitted stores the last entry that has been passed to commit_, but it + // may not yet have been executed. 
Raft's Consensus file holds lastApplied_ + uint64_t lastCommitted_; // Protected by mutex_ + Role role_; // Protected by mutex_ + //int leaderId_; // Protected by mutex_ + std::vector votes_; // Protected by mutex_ + std::vector> inflightVecs_; // Protected by mutex_ + //std::chrono::steady_clock::time_point last_ae_time_; + //std::chrono::steady_clock::time_point last_heartbeat_time_; // Protected by mutex_ + int64_t snapshot_last_index_; + int64_t snapshot_last_term_; + + bool is_stop_; + const uint64_t quorum_; + + // for limiting AppendEntries batch sizing + static constexpr size_t maxHeaderBytes = 64; + static constexpr size_t maxBytes = 64 * 1024; + static constexpr size_t maxEntries = 16; + static constexpr size_t maxInFlightPerFollower = 4; + static constexpr std::chrono::milliseconds AEResponseDeadline{300}; // in milliseconds + + SignatureVerifier* verifier_; + LeaderElectionManager* leader_election_manager_; + //Stats* global_stats_; + ReplicaCommunicator* replica_communicator_; + RaftRecovery* recovery_; + +#ifdef RAFT_TEST_MODE + public: + void SetStateForTest(RaftStatePatch patch) { + std::lock_guard lk(mutex_); + if (patch.currentTerm) currentTerm_ = *patch.currentTerm; + if (patch.votedFor) votedFor_ = *patch.votedFor; + if (patch.commitIndex) commitIndex_ = *patch.commitIndex; + if (patch.lastCommitted) lastCommitted_ = *patch.lastCommitted; + if (patch.role) role_ = *patch.role; + + if (patch.log) { + log_ = *patch.log; + lastLogIndex_ = log_.empty() ? 
0 : log_.size() - 1; + } + + if (patch.nextIndex) nextIndex_ = *patch.nextIndex; + if (patch.matchIndex) matchIndex_ = *patch.matchIndex; + if (patch.votes) votes_ = *patch.votes; + } + + uint64_t GetCurrentTerm() const { + std::lock_guard lock(mutex_); + return currentTerm_; + } + + int GetVotedFor() const { + std::lock_guard lock(mutex_); + return votedFor_; + } + + const std::vector& GetLog() const { + std::lock_guard lock(mutex_); + return log_; + } + + void PrintLog(std::ostream& os) const { + os << "Log entries (count = " << log_.size() << "):\n"; + + for (size_t i = 0; i < log_.size(); ++i) { + const auto& entry = log_[i]; + + os << " [" << i << "] " + << "term=" << entry.entry.term() << ", command=\"" + << entry.entry.command() << "\"" + << ", serializedSize=" << entry.GetSerializedSize() << "\n"; + } + } + + size_t GetLogSize() const { + std::lock_guard lock(mutex_); + return log_.size(); + } + + uint64_t GetLastLogIndexFromLog() const { + std::lock_guard lock(mutex_); + return log_.empty() ? 
0 : log_.size() - 1; + } + + std::vector GetNextIndex() const { + std::lock_guard lock(mutex_); + return nextIndex_; + } + + std::vector GetMatchIndex() const { + std::lock_guard lock(mutex_); + return matchIndex_; + } + + uint64_t GetHeartBeatsSentThisTerm() const { + std::lock_guard lock(mutex_); + return heartBeatsSentThisTerm_; + } + + uint64_t GetLastLogIndex() const { + std::lock_guard lock(mutex_); + return lastLogIndex_; + } + + uint64_t GetCommitIndex() const { + std::lock_guard lock(mutex_); + return commitIndex_; + } + + uint64_t GetLastCommitted() const { + std::lock_guard lock(mutex_); + return lastCommitted_; + } + + Role GetRole() const { + std::lock_guard lock(mutex_); + return role_; + } + + std::vector GetVotes() const { + std::lock_guard lock(mutex_); + return votes_; + } + + std::vector> GetInFlightVecs() const { + std::lock_guard lock(mutex_); + return inflightVecs_; + } + +#endif +}; + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/raft_append_entries_response_test.cpp b/platform/consensus/ordering/raft/algorithm/raft_append_entries_response_test.cpp new file mode 100644 index 0000000000..aaad9e8d8f --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/raft_append_entries_response_test.cpp @@ -0,0 +1,235 @@ +#include "platform/consensus/ordering/raft/algorithm/raft_tests.h" + +namespace resdb { +namespace raft { + +// Test 1: A leader receiving an AppendEntriesResponse success and updating the +// follower's matchIndex. 
+TEST_F(RaftTest, LeaderReceivesAppendEntriesResponseSuccess) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(true); + aeResponse.set_term(1); + aeResponse.set_id(2); + aeResponse.set_lastlogindex(2); + + raft_->SetStateForTest({.currentTerm = 1, + .commitIndex = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {0, "Transaction 1"}, + {0, "Transaction 2"}, + }, + true), + .matchIndex = std::vector{0, 2, 0, 0, 0}}); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); + EXPECT_THAT(raft_->GetMatchIndex(), ::testing::ElementsAre(0, 2, 2, 0, 0)); +} + +// Test 2: A leader receiving an AppendEntriesResponse from a follower that in a +// newer term. +TEST_F(RaftTest, LeaderReceivesAppendEntriesResponseFromNewerTerm) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1); + + raft_->SetStateForTest({ + .currentTerm = 1, + .role = Role::LEADER, + }); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(false); + aeResponse.set_term(2); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_FALSE(success); + EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER); +} + +// Test 3: A leader receiving an AppendEntriesResponse success, updating the +// follower's matchIndex, and committing a new entry. 
+TEST_F(RaftTest, LeaderReceivesAppendEntriesResponseSuccessAndCommits) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_commit, Commit(_)).Times(1); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(true); + aeResponse.set_term(1); + aeResponse.set_id(2); + aeResponse.set_lastlogindex(2); + + raft_->SetStateForTest({.currentTerm = 1, + .commitIndex = 0, + .lastCommitted = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {1, "Transaction 1"}, + {1, "Transaction 2"}, + }, + true), + .nextIndex = std::vector{1, 2, 2, 2, 2}, + .matchIndex = std::vector{0, 2, 0, 1, 0}}); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); + EXPECT_THAT(raft_->GetMatchIndex(), ::testing::ElementsAre(0, 2, 2, 1, 0)); + EXPECT_EQ(raft_->GetCommitIndex(), 1); +} + +// Test 4: A leader receiving an AppendEntriesResponse success and catching up a +// follower that is behind. +TEST_F(RaftTest, LeaderCatchesUpFollowerThatIsBehind) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_call, Call(_, _, _)) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& ae = dynamic_cast(msg); + EXPECT_EQ(ae.entries_size(), 1); + // TODO: Use serialized string instead of manually doing it. 
+ EXPECT_EQ(ae.entries(0).command(), "\n\rTransaction 2"); + EXPECT_EQ(node_id, 2); + return 0; + })); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(true); + aeResponse.set_term(1); + aeResponse.set_id(2); + aeResponse.set_lastlogindex(1); + + raft_->SetStateForTest({ + .currentTerm = 1, + .commitIndex = 0, + .lastCommitted = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {1, "Transaction 1"}, + {1, "Transaction 2"}, + }, + true), + }); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); +} + +// Test 5: A leader receiving an AppendEntriesResponse Failure and catching up a +// follower that is behind. +TEST_F(RaftTest, LeaderCatchesUpFollowerThatIsBehindFailure) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_call, Call(_, _, _)) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& ae = dynamic_cast(msg); + // TODO: Use serialized string instead of manually doing it. + EXPECT_EQ(ae.entries(0).command(), "\n\rTransaction 1"); + EXPECT_EQ(ae.entries(1).command(), "\n\rTransaction 2"); + EXPECT_EQ(ae.entries_size(), 2); + EXPECT_EQ(node_id, 2); + return 0; + })); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(false); + aeResponse.set_term(1); + aeResponse.set_id(2); + aeResponse.set_lastlogindex(0); + + raft_->SetStateForTest({ + .currentTerm = 1, + .commitIndex = 0, + .lastCommitted = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {1, "Transaction 1"}, + {1, "Transaction 2"}, + }, + true), + }); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); +} + +// Test 6: A follower ignores an AppendEntriesResponse. 
+TEST_F(RaftTest, FollowerIgnoresAppendEntriesResponse) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_call, Call(_, _, _)).Times(0); + + AppendEntriesResponse aeResponse; + aeResponse.set_term(1); + + raft_->SetStateForTest({ + .currentTerm = 1, + .role = Role::FOLLOWER, + }); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); +} + +// Test 7: A leader ignores an AppendEntriesResponse from an outdated term. +TEST_F(RaftTest, LeaderIgnoresAppendEntriesResponseFromOutdatedTerm) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_call, Call(_, _, _)).Times(0); + + AppendEntriesResponse aeResponse; + aeResponse.set_term(1); + + raft_->SetStateForTest({ + .currentTerm = 2, + .role = Role::LEADER, + }); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); +} + +// Test 8: A leader does not advance its commit index from a previous term if it +// has not replicated an entry from its current term. 
+TEST_F(RaftTest, + LeaderReceivesAppendEntriesResponseSuccessAndDoesNotCommitOldTerm) { + EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0); + EXPECT_CALL(mock_commit, Commit(_)).Times(0); + + AppendEntriesResponse aeResponse; + aeResponse.set_success(true); + aeResponse.set_term(1); + aeResponse.set_id(2); + aeResponse.set_lastlogindex(2); + + raft_->SetStateForTest({.currentTerm = 1, + .commitIndex = 0, + .lastCommitted = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {0, "Transaction 1"}, + {0, "Transaction 2"}, + }, + true), + .nextIndex = std::vector{0, 2, 2, 2, 2}, + .matchIndex = std::vector{0, 2, 0, 1, 0}}); + + bool success = raft_->ReceiveAppendEntriesResponse( + std::make_unique(aeResponse)); + EXPECT_TRUE(success); + EXPECT_THAT(raft_->GetMatchIndex(), ::testing::ElementsAre(0, 2, 2, 1, 0)); + EXPECT_EQ(raft_->GetCommitIndex(), 0); +} + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/ordering/raft/algorithm/raft_append_entries_test.cpp b/platform/consensus/ordering/raft/algorithm/raft_append_entries_test.cpp new file mode 100644 index 0000000000..fbbba46b35 --- /dev/null +++ b/platform/consensus/ordering/raft/algorithm/raft_append_entries_test.cpp @@ -0,0 +1,784 @@ +#include "platform/consensus/ordering/raft/algorithm/raft_tests.h" + +namespace resdb { +namespace raft { + +// Test 1: A follower receiving a client transaction should reject it. +TEST_F(RaftTest, FollowerRejectsClientTransaction) { + EXPECT_CALL(mock_call, Call(_, _, _)).Times(0); + EXPECT_CALL(mock_broadcast, Broadcast(_, _)).Times(0); + + auto req = std::make_unique(); + req->set_seq(1); + raft_->SetStateForTest({ + .currentTerm = 0, + .role = Role::FOLLOWER, + .log = CreateLogEntries({}, true), + }); + + bool success = raft_->ReceiveTransaction(std::move(req)); + EXPECT_FALSE(success); +} + +// Test 2: A leader receiving a client transaction should send an AppendEntries +// to all other replicas. 
+TEST_F(RaftTest, LeaderSendsAppendEntriesUponClientTransaction) { + EXPECT_CALL(mock_call, Call(_, _, _)).Times(3); + EXPECT_CALL(*leader_election_manager_, OnAeBroadcast()).Times(1); + + auto req = std::make_unique(); + req->set_seq(1); + raft_->SetStateForTest({ + .currentTerm = 0, + .role = Role::LEADER, + .log = CreateLogEntries({}, true), + }); + + bool success = raft_->ReceiveTransaction(std::move(req)); + EXPECT_TRUE(success); +} + +// Test 3: Sent AppendEntries should be based on the follower's nextIndex. +TEST_F(RaftTest, LeaderSendsAppendEntriesBasedOnNextIndex) { + EXPECT_CALL(mock_call, Call(_, _, _)) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& ae = dynamic_cast(msg); + EXPECT_EQ(node_id, 2); + EXPECT_EQ(ae.prevlogindex(), 2); + EXPECT_EQ(ae.entries().size(), 3); + return 0; + })) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& ae = dynamic_cast(msg); + EXPECT_EQ(node_id, 3); + EXPECT_EQ(ae.prevlogindex(), 1); + EXPECT_EQ(ae.entries().size(), 4); + return 0; + })) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& ae = dynamic_cast(msg); + EXPECT_EQ(node_id, 4); + EXPECT_EQ(ae.prevlogindex(), 0); + EXPECT_EQ(ae.entries().size(), 5); + return 0; + })); + EXPECT_CALL(*leader_election_manager_, OnAeBroadcast()).Times(1); + + raft_->SetStateForTest({.currentTerm = 0, + .role = Role::LEADER, + .log = CreateLogEntries( + { + {0, "Term 0 Transaction 1"}, + {0, "Term 0 Transaction 2"}, + {0, "Term 0 Transaction 3"}, + {0, "Term 0 Transaction 4"}, + }, + true), + .nextIndex = std::vector{1, 4, 3, 2, 1}}); + + auto req = std::make_unique(); + req->set_seq(5); + + bool success = raft_->ReceiveTransaction(std::move(req)); + EXPECT_TRUE(success); +} + +// Test 4: A follower receiving 1 AppendEntries with multiple entries that it +// can accept. 
+TEST_F(RaftTest, FollowerAddsAppendEntriesWithMultipleEntries) { + EXPECT_CALL(mock_call, Call(_, _, _)) + .WillOnce(::testing::Invoke( + [](int type, const google::protobuf::Message& msg, int node_id) { + const auto& aer = dynamic_cast(msg); + EXPECT_TRUE(aer.success()); + EXPECT_EQ(aer.lastlogindex(), 3); + return 0; + })); + EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1); + + auto aefields = CreateAeFields( + /*term=*/0, + /*leaderId=*/2, + /*prevLogIndex=*/0, + /*prevLogTerm=*/0, + /*entries=*/ + CreateLogEntries({ + {0, "Transaction 1"}, + {0, "Transaction 2"}, + {0, "Transaction 3"}, + }), + /*leaderCommit=*/0, + /*followerId=*/1); + + auto aemessage = CreateAeMessage(aefields); + raft_->SetStateForTest({ + .currentTerm = 0, + .role = Role::FOLLOWER, + .log = CreateLogEntries({}, true), + }); + + bool success = raft_->ReceiveAppendEntries( + std::make_unique(std::move(aemessage))); + EXPECT_TRUE(success); +} + +// Test 5: A follower receiving multiple AppendEntries that it can accept. 
+TEST_F(RaftTest, FollowerAddsMultipleAppendEntries) {
+  // Three replies are expected, in order; each must report success and a
+  // lastlogindex of 1, 2 and 3 respectively.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 1);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 2);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 3);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(3);
+
+  // Three single-entry requests from leader 2 whose prevLogIndex advances
+  // 0 -> 1 -> 2, so each one extends the log left by the previous request.
+  auto aefields1 = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/0,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {0, "Transaction 1"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+
+  auto aefields2 = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/1,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {0, "Transaction 2"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+
+  auto aefields3 = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/2,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {0, "Transaction 3"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+
+  auto aemessage1 = CreateAeMessage(aefields1);
+  auto aemessage2 = CreateAeMessage(aefields2);
+  auto aemessage3 = CreateAeMessage(aefields3);
+
+  // Follower in term 0 with no user entries in its log.
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries({}, true),
+  });
+
+  bool success1 = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage1)));
+  EXPECT_TRUE(success1);
+
+  bool success2 = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage2)));
+  EXPECT_TRUE(success2);
+
+  bool success3 = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage3)));
+  EXPECT_TRUE(success3);
+}
+
+// Test 6: A follower rejects Append Entries because its own entry at
+// prevLogIndex does not have the same term.
+TEST_F(RaftTest, FollowerRejectsMismatchedTermAtPrevLogIndex) {
+  // Exactly one reply is expected: a rejection that reports lastlogindex 1.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_FALSE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 1);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // The request claims term 2 at prevLogIndex 1; the follower's log below
+  // holds a term-1 entry there.
+  auto aefields = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/1,
+      /*prevLogTerm=*/2,
+      /*entries=*/
+      CreateLogEntries({
+          {2, "Term 2 Transaction 1"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {1, "Term 1 Transaction 1"},
+          },
+          true),
+  });
+
+  auto aemessage = CreateAeMessage(aefields);
+
+  // The call is expected to return true even though the reply it sends
+  // carries success=false.
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+}
+
+// Test 7: A follower rejects Append Entries because it does not have an entry at
+// prevLogIndex.
+TEST_F(RaftTest, FollowerRejectsMissingIndex) {
+  // Exactly one reply is expected: a rejection that reports lastlogindex 0.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_FALSE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 0);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // prevLogIndex is 1, but the follower state below has no user entries.
+  auto aefields = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/1,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {0, "Transaction 2"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries({}, true),
+  });
+
+  // The call is expected to return true even though the reply it sends
+  // carries success=false.
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+}
+
+// Test 8: A follower receiving 1 AppendEntries with multiple entries and
+// needing to truncate part of its log.
+TEST_F(RaftTest, FollowerAddsAppendEntriesAndTruncatesLog) {
+  // One successful reply is expected, reporting lastlogindex 3.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 3);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // Two term-1 entries that follow index 1; the first lands on index 2,
+  // where the follower currently holds a term-0 entry.
+  auto aefields = CreateAeFields(
+      /*term=*/1,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/1,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {1, "Term 1 Transaction 1"},
+          {1, "Term 1 Transaction 2"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},  // index 1
+              {0, "Term 0 Transaction 2"},  // mismatched entry will be removed
+          },
+          true),
+  });
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+
+  // Index 0 is expected to hold a term-0 "COMMON_PREFIX" entry, index 1 the
+  // surviving term-0 entry, and indices 2-3 the two new term-1 entries.
+  const auto& raft_log = raft_->GetLog();
+  EXPECT_EQ(raft_log[0].entry.term(), 0);
+  EXPECT_EQ(raft_log[0].entry.command(), "COMMON_PREFIX");
+  EXPECT_EQ(raft_log[1].entry.term(), 0);
+  // TODO: Use serialized string instead of manually doing it.
+  // The "\n\x14" prefix looks like a protobuf tag byte (0x0A) followed by the
+  // payload length 0x14 = 20 -- presumably the serialized request; confirm.
+  EXPECT_EQ(raft_log[1].entry.command(), "\n\x14Term 0 Transaction 1");
+  EXPECT_EQ(raft_log[2].entry.term(), 1);
+  EXPECT_EQ(raft_log[2].entry.command(), "\n\x14Term 1 Transaction 1");
+  EXPECT_EQ(raft_log[3].entry.term(), 1);
+  EXPECT_EQ(raft_log[3].entry.command(), "\n\x14Term 1 Transaction 2");
+  EXPECT_TRUE(success);
+}
+
+// Test 9: A follower increases its commitIndex.
+TEST_F(RaftTest, FollowerIncreasesCommitIndex) {
+  // One successful reply is expected, reporting lastlogindex 5.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 5);
+            return 0;
+          }));
+  // commitIndex starts at 1 and leaderCommit is 3, so two commits are
+  // expected (presumably indices 2 and 3).
+  EXPECT_CALL(mock_commit, Commit(_)).Times(2);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // No entries in the request; only leaderCommit differs from the follower.
+  auto aefields = CreateAeFields(
+      /*term=*/1,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/5,
+      /*prevLogTerm=*/1,
+      /*entries=*/CreateLogEntries({}),
+      /*leaderCommit=*/3,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .commitIndex = 1,
+      .lastCommitted = 1,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {1, "Term 1 Transaction 1"},
+              {1, "Term 1 Transaction 2"},
+              {1, "Term 1 Transaction 3"},
+              {1, "Term 1 Transaction 4"},
+              {1, "Term 1 Transaction 5"},
+          },
+          true),
+  });
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+
+  EXPECT_TRUE(success);
+  EXPECT_EQ(raft_->GetCommitIndex(), 3);
+}
+
+// Test 10: A follower increases its commitIndex, but not past its own log size.
+TEST_F(RaftTest, FollowerIncreasesCommitIndexCappedAtLogSize) {
+  // One successful reply is expected, reporting lastlogindex 5.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 5);
+            return 0;
+          }));
+  // commitIndex starts at 1 and the log ends at index 5, so four commits are
+  // expected even though leaderCommit is 7.
+  EXPECT_CALL(mock_commit, Commit(_)).Times(4);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // leaderCommit (7) is beyond the follower's last log index (5).
+  auto aefields = CreateAeFields(
+      /*term=*/1,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/5,
+      /*prevLogTerm=*/1,
+      /*entries=*/CreateLogEntries({}),
+      /*leaderCommit=*/7,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .commitIndex = 1,
+      .lastCommitted = 1,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {1, "Term 1 Transaction 1"},
+              {1, "Term 1 Transaction 2"},
+              {1, "Term 1 Transaction 3"},
+              {1, "Term 1 Transaction 4"},
+              {1, "Term 1 Transaction 5"},
+          },
+          true),
+  });
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+
+  EXPECT_TRUE(success);
+  // The commit index must stop at the last log index, not at leaderCommit.
+  EXPECT_EQ(raft_->GetCommitIndex(), 5);
+}
+
+// Test 11: A candidate rejecting an AppendEntries from an outdated term and
+// staying candidate.
+TEST_F(RaftTest, CandidateRejectsAppendEntriesFromOutdatedTerm) {
+  // Exactly one reply is expected: a rejection that reports lastlogindex 0.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_FALSE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 0);
+            return 0;
+          }));
+  // A stale-term request must neither change the role nor count as a
+  // heartbeat.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  // The request is from term 1; the candidate below is already in term 2.
+  auto aefields = CreateAeFields(
+      /*term=*/1,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/0,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {1, "Transaction 1"},
+          {1, "Transaction 2"},
+          {1, "Transaction 3"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries({}, true),
+  });
+
+  // The call is expected to return true even though the reply it sends
+  // carries success=false.
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+}
+
+// Test 12: A candidate rejecting an AppendEntries because their log is further
+// behind, but it is in the same term so they still demote.
+TEST_F(RaftTest, CandidateRejectsAppendEntriesFromSameTerm) {
+  // Exactly one reply is expected: a rejection that reports lastlogindex 1.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_FALSE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 1);
+            return 0;
+          }));
+  // Although the entries are rejected, one role change and one heartbeat
+  // notification are expected.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // Same term as the candidate (2), but prevLogIndex 2 is past the single
+  // entry in the candidate's log below.
+  auto aefields = CreateAeFields(
+      /*term=*/2,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/2,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {2, "Transaction 1"},
+          {2, "Transaction 2"},
+          {2, "Transaction 3"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries({{1, "Old Transaction 1"}}, true),
+  });
+
+  // The call is expected to return true even though the reply it sends
+  // carries success=false.
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+}
+
+// Test 13: A candidate receiving an AppendEntries it can accept from a newer
+// term.
+TEST_F(RaftTest, CandidateReceivesNewerTermWithAppendEntriesItCanAccept) {
+  // One successful reply is expected, reporting lastlogindex 3.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 3);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // The request is from term 2; the candidate below is in term 1.
+  auto aefields = CreateAeFields(
+      /*term=*/2,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/2,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {2, "Transaction 1"},
+      }),
+      /*leaderCommit=*/2,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .lastCommitted = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "old-1"},
+              {0, "old-2"},
+          },
+          true),
+  });
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+  // Accepting the request must leave the node a follower.
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 14: A candidate receiving an AppendEntries that it can accept from the
+// same term but further along.
+TEST_F(RaftTest, CandidateReceivesSameTermWithAppendEntriesItCanAccept) {
+  // One successful reply is expected, reporting lastlogindex 3.
+  // NOTE(review): `dynamic_cast(msg)` and `std::make_unique(...)` have no
+  // template arguments in this copy of the patch (likely lost in extraction);
+  // restore them from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 3);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // The request is from term 2, the same term as the candidate below.
+  auto aefields = CreateAeFields(
+      /*term=*/2,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/2,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {2, "Transaction 1"},
+      }),
+      /*leaderCommit=*/2,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .lastCommitted = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "old-1"},
+              {0, "old-2"},
+          },
+          true),
+  });
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+  // Accepting the request must leave the node a follower.
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 15: A follower receiving a leaderCommit whose index is less than its own
+// commitIndex does not lower its commitIndex.
+TEST_F(RaftTest, FollowerWillNotLowerCommitIndex) {
+  // One reply is expected; its contents are not inspected.
+  // NOTE(review): the reply's success flag is unchecked, and with
+  // prevLogTerm=2 at prevLogIndex=0 it is unclear from this file whether the
+  // request is accepted -- confirm the test reaches the leaderCommit path.
+  // NOTE(review): `std::make_unique(...)` has no template argument in this
+  // copy of the patch (likely lost in extraction); restore it from the
+  // upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(
+          ::testing::Invoke([](int type, const google::protobuf::Message& msg,
+                               int node_id) { return 0; }));
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // leaderCommit (2) is lower than the follower's commitIndex (4) set below.
+  auto aefields = CreateAeFields(
+      /*term=*/1,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/0,
+      /*prevLogTerm=*/2,
+      /*entries=*/
+      CreateLogEntries({}),
+      /*leaderCommit=*/2,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .commitIndex = 4,
+      .lastCommitted = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Transaction 1"},
+              {0, "Transaction 2"},
+          },
+          true),
+  });
+
+  raft_->PrintDebugStateLocked();
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+  // The property under test: the commit index must stay at its prior value.
+  EXPECT_EQ(raft_->GetCommitIndex(), 4);
+}
+
+// Test 16: A leader ignores an AppendEntries from itself
+TEST_F(RaftTest, LeaderIgnoresAppendEntriesFromSelf) {
+  // Nothing may be sent and the election manager must not be notified.
+  // NOTE(review): `std::make_unique(...)` has no template argument in this
+  // copy of the patch (likely lost in extraction); restore it from the
+  // upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  // leaderId and followerId are both 1, i.e. the request names the receiver
+  // as its own sender.
+  auto aefields = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/1,
+      /*prevLogIndex=*/0,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({
+          {0, "Transaction 1"},
+      }),
+      /*leaderCommit=*/0,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .lastCommitted = 0,
+      .role = Role::LEADER,
+      .log = CreateLogEntries({}, true),
+  });
+
+  // Unlike the rejection tests, the call itself is expected to return false.
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_FALSE(success);
+}
+
+// Test 17: A follower receiving a heartbeat will advance its commit index.
+TEST_F(RaftTest, FollowerAdvancesCommitIndexOnHeartbeat) {
+  // One reply is expected; its contents are not inspected.
+  // NOTE(review): `std::make_unique(...)` has no template argument in this
+  // copy of the patch (likely lost in extraction); restore it from the
+  // upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(
+          ::testing::Invoke([](int type, const google::protobuf::Message& msg,
+                               int node_id) { return 0; }));
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  // A request with no entries whose leaderCommit (2) is ahead of the
+  // follower's commitIndex (0).
+  auto aefields = CreateAeFields(
+      /*term=*/0,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/2,
+      /*prevLogTerm=*/0,
+      /*entries=*/
+      CreateLogEntries({}),
+      /*leaderCommit=*/2,
+      /*followerId=*/1);
+  auto aemessage = CreateAeMessage(aefields);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .commitIndex = 0,
+      .lastCommitted = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Transaction 1"},
+              {0, "Transaction 2"},
+          },
+          true),
+  });
+
+  raft_->PrintDebugStateLocked();
+
+  bool success = raft_->ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+  EXPECT_EQ(raft_->GetCommitIndex(), 2);
+}
+
+// Test 18: A leader correctly sends a heartbeat.
+TEST_F(RaftTest, LeaderCorrectlySendsHeartbeat) {
+  // Three empty messages are expected, to nodes 2, 3 and 4 in that order.
+  // Each expected prevlogindex equals nextIndex[node_id] - 1 for the state
+  // set below (nextIndex = {1, 4, 3, 2, 1}).
+  // NOTE(review): `dynamic_cast(msg)` has no template argument in this copy
+  // of the patch, and `std::vector{...}` may have lost its element type;
+  // restore both from the upstream file.
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& ae = dynamic_cast(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(ae.prevlogindex(), 2);
+            EXPECT_EQ(ae.entries().size(), 0);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& ae = dynamic_cast(msg);
+            EXPECT_EQ(node_id, 3);
+            EXPECT_EQ(ae.prevlogindex(), 1);
+            EXPECT_EQ(ae.entries().size(), 0);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& ae = dynamic_cast(msg);
+            EXPECT_EQ(node_id, 4);
+            EXPECT_EQ(ae.prevlogindex(), 0);
+            EXPECT_EQ(ae.entries().size(), 0);
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  raft_->SetStateForTest({.currentTerm = 1,
+                          .votedFor = 1,
+                          .commitIndex = 0,
+                          .lastCommitted = 0,
+                          .role = Role::LEADER,
+                          .log = CreateLogEntries(
+                              {
+                                  {0, "Transaction 1"},
+                                  {1, "Transaction 2"},
+                              },
+                              true),
+                          .nextIndex = std::vector{1, 4, 3, 2, 1},
+                          .matchIndex = std::vector{0, 2, 0, 1, 0},
+                          .votes = std::vector{1, 3, 2}});
+
+  raft_->SendHeartBeat();
+
+  // Sending a heartbeat must leave every piece of state set above unchanged.
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetVotedFor(), 1);
+  EXPECT_EQ(raft_->GetCommitIndex(), 0);
+  EXPECT_EQ(raft_->GetLastCommitted(), 0);
+  EXPECT_EQ(raft_->GetRole(), Role::LEADER);
+  auto log = raft_->GetLog();
+  // Maybe check that the log itself is equal
+  // Size 3 = the two entries above plus one more at index 0.
+  EXPECT_EQ(log.size(), 3);
+  EXPECT_EQ(raft_->GetLastLogIndex(), log.size() - 1);
+  EXPECT_THAT(raft_->GetNextIndex(), ::testing::ElementsAre(1, 4, 3, 2, 1));
+  EXPECT_THAT(raft_->GetMatchIndex(), ::testing::ElementsAre(0, 2, 0, 1, 0));
+  EXPECT_THAT(raft_->GetVotes(), ::testing::ElementsAre(1, 3, 2));
+}
+
+}  // namespace raft
+}  // namespace resdb
diff 
--git a/platform/consensus/ordering/raft/algorithm/raft_integration_test.cpp b/platform/consensus/ordering/raft/algorithm/raft_integration_test.cpp
new file mode 100644
index 0000000000..61841dc0ad
--- /dev/null
+++ b/platform/consensus/ordering/raft/algorithm/raft_integration_test.cpp
@@ -0,0 +1,480 @@
+// raft_integration_test.cpp
+//
+// Integration test: Raft state correctly restored after RecoverFromLogs().
+// Uses a real RaftRecovery (seeded with WAL data) and a real Raft.
+
+// NOTE(review): the three bare `#include` lines below have lost their
+// <header> targets in this copy of the patch; restore them from the upstream
+// file.
+#include
+#include
+
+#include
+
+#include "platform/consensus/checkpoint/mock_checkpoint.h"
+#include "platform/consensus/ordering/raft/algorithm/raft_test_util.h"
+#include "platform/consensus/recovery/raft_recovery.h"
+
+namespace resdb {
+namespace raft {
+
+using resdb::raft::test_utils::CreateAeFields;
+using resdb::raft::test_utils::CreateAeMessage;
+using resdb::raft::test_utils::CreateLogEntries;
+using resdb::raft::test_utils::GenerateConfig;
+using resdb::raft::test_utils::MockBroadcastFunction;
+using resdb::raft::test_utils::MockCommitFunction;
+using resdb::raft::test_utils::MockSendMessageFunction;
+using ::testing::_;
+using ::testing::AnyNumber;
+using ::testing::Invoke;
+
+namespace {
+
+// Path handed to the recovery config; the fixture's SetUp() deletes this
+// path's parent directory before every test.
+const std::string kLogPath = "./log/raft_integration_test_log";
+
+// Builds the ResDBConfig shared by these tests: recovery enabled at kLogPath
+// with a 1024 buffer size and a 3 s checkpoint time, and four replicas
+// (ids 1-4 on 127.0.0.1, ports 1234-1237). The second constructor argument
+// repeats replica 1 -- presumably the local node; confirm against
+// ResDBConfig.
+ResDBConfig MakeConfig() {
+  ResConfigData data;
+  data.set_recovery_enabled(true);
+  data.set_recovery_path(kLogPath);
+  data.set_recovery_buffer_size(1024);
+  data.set_recovery_ckpt_time_s(3);
+  return ResDBConfig({GenerateReplicaInfo(1, "127.0.0.1", 1234),
+                      GenerateReplicaInfo(2, "127.0.0.1", 1235),
+                      GenerateReplicaInfo(3, "127.0.0.1", 1236),
+                      GenerateReplicaInfo(4, "127.0.0.1", 1237)},
+                     GenerateReplicaInfo(1, "127.0.0.1", 1234), data);
+}
+
+// Mirrors what Consensus::RecoverFromLogs() does.
+// Replays everything `recovery` has persisted into `raft`:
+//   - the metadata callback restores the current term, votedFor and the
+//     snapshot's last index/term;
+//   - the record callback appends kEntry records to the log and applies
+//     kTruncation records, failing the test on a record with no payload;
+//   - the start-point callback is a no-op.
+// Every Raft setter is called with writeMetadata=false.
+// NOTE(review): `std::unique_ptr record` has no template argument in this
+// copy of the patch (likely lost in extraction); restore it from the
+// upstream file.
+void RecoverFromLogs(RaftRecovery& recovery, Raft& raft) {
+  recovery.ReadLogs(
+      [&](const RaftMetadata& metadata) {
+        LOG(INFO) << "loading metadata file: term: " << metadata.current_term
+                  << " votedFor: " << metadata.voted_for
+                  << " snapshot_last_index: " << metadata.snapshot_last_index
+                  << " snapshot_last_term: " << metadata.snapshot_last_term;
+        raft.SetCurrentTerm(metadata.current_term, /*writeMetadata=*/false);
+        raft.SetVotedFor(metadata.voted_for, /*writeMetadata=*/false);
+        raft.SetSnapshotLastIndexAndTerm(metadata.snapshot_last_index,
+                                         metadata.snapshot_last_term,
+                                         /*writeMetadata=*/false);
+      },
+      [&](std::unique_ptr record) {
+        LOG(INFO) << "Replaying record with seq: " << record->seq();
+        switch (record->payload_case()) {
+          case WALRecord::kEntry: {
+            LogEntry logEntry;
+            logEntry.entry = record->entry();
+            LOG(INFO) << "Adding entry from term: " << logEntry.entry.term();
+            raft.AddToLog(logEntry, /*writeMetadata=*/false);
+            break;
+          }
+          case WALRecord::kTruncation:
+            raft.TruncateLog(record->truncation().truncate_from_index(),
+                             /*writeMetadata=*/false);
+            break;
+          case WALRecord::PAYLOAD_NOT_SET:
+            FAIL() << "Unexpected PAYLOAD_NOT_SET record";
+            break;
+        }
+      },
+      /*set_start_point=*/[](int) {});
+}
+
+}  // namespace
+
+// Fixture for the recovery integration tests. Each test starts from an empty
+// log directory and shares one config plus the mocks declared below.
+class RaftRecoveryIntegrationTest : public ::testing::Test {
+ private:
+  // Local commit mock; inside this class it hides the MockCommitFunction
+  // brought in by the using-declaration at the top of the file.
+  class MockCommitFunction {
+   public:
+    MOCK_METHOD(int, Commit, (const google::protobuf::Message&));
+  };
+
+ protected:
+  // Removes the directory that contains kLogPath.
+  void SetUp() override {
+    std::filesystem::remove_all(std::filesystem::path(kLogPath).parent_path());
+  }
+
+  ResDBConfig config_ = MakeConfig();
+  MockCheckPoint checkpoint_;
+  MockSendMessageFunction mock_call;
+  MockBroadcastFunction mock_broadcast;
+  MockCommitFunction mock_commit;
+};
+
+// Test 1: Restore basic metadata and log entries.
+TEST_F(RaftRecoveryIntegrationTest, RaftStateRestoredAfterRecovery) {
+  // Phase 1: seed the WAL with metadata (term 5, votedFor 2, no snapshot) and
+  // three entries whose term equals their index. The scope ends so this
+  // RaftRecovery is destroyed before the second one is created.
+  {
+    RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+    recovery.WriteMetadata(/*current_term=*/5, /*voted_for=*/2,
+                           /*snapshot_last_index=*/0, /*snapshot_last_term=*/0);
+
+    for (int i = 1; i <= 3; ++i) {
+      Entry e;
+      e.set_term(i);
+      ClientTestRequest req;
+      req.set_value("cmd-" + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+  }
+
+  // Phase 2: a fresh RaftRecovery and Raft replay what phase 1 wrote.
+  MockSignatureVerifier verifier;
+  ResDBConfig config = MakeConfig();
+  MockLeaderElectionManager lem(config);
+  MockReplicaCommunicator comm;
+  MockCheckPoint ckpt;
+
+  RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+  Raft raft(/*id=*/1, /*f=*/1, /*total=*/4, &verifier, &lem, &comm, &recovery);
+
+  RecoverFromLogs(recovery, raft);
+
+  // --- Assertions ---
+  EXPECT_EQ(raft.GetCurrentTerm(), 5u);
+  EXPECT_EQ(raft.GetVotedFor(), 2);
+  EXPECT_EQ(raft.GetSnapshotLastIndex(), 0u);
+
+  // Log: index 0 is the sentinel (term=0), indices 1–3 are the replayed
+  // entries.
+  ASSERT_EQ(raft.GetLogSize(), 4u);
+  for (int i = 1; i <= 3; ++i) {
+    // NOTE(review): if GetLog() returns by value, `le` refers to an element
+    // of a destroyed temporary -- confirm it returns a reference.
+    const auto& le = raft.GetLog()[i];
+    EXPECT_EQ(le.entry.term(), i);
+    ClientTestRequest req;
+    req.ParseFromString(le.entry.command());
+    EXPECT_EQ(req.value(), "cmd-" + std::to_string(i));
+  }
+}
+
+// Test 2: Restore the log using a checkpoint and the Recovery WAL.
+TEST_F(RaftRecoveryIntegrationTest,
+       RaftStateRestoredAfterRecoveryWithCheckpoint) {
+  // After recovery one AppendEntries is delivered; two commits and one
+  // successful reply reporting lastlogindex 13 are expected.
+  // NOTE(review): `dynamic_cast(msg)`, `std::promise` and
+  // `std::make_unique(...)` have no template arguments in this copy of the
+  // patch (likely lost in extraction); restore them from the upstream file.
+  EXPECT_CALL(mock_commit, Commit(_)).Times(2);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& aer = dynamic_cast(msg);
+            EXPECT_TRUE(aer.success());
+            EXPECT_EQ(aer.lastlogindex(), 13);
+            return 0;
+          }));
+
+  // Phase 1: seed the WAL. Entries 1-8 are written, the checkpoint mock is
+  // released and observed to be polled a second time, then entries 9-10 are
+  // written.
+  {
+    std::promise insert_done, ckpt_fired;
+    auto insert_done_future = insert_done.get_future();
+    auto ckpt_fired_future = ckpt_fired.get_future();
+
+    // The first GetStableCheckpoint() call blocks until entries 1-8 exist;
+    // the second signals the test body. Every call reports 5.
+    int call_count = 0;
+    EXPECT_CALL(checkpoint_, GetStableCheckpoint())
+        .WillRepeatedly(Invoke([&]() -> uint64_t {
+          ++call_count;
+          if (call_count == 1)
+            insert_done_future.get();
+          else if (call_count == 2)
+            ckpt_fired.set_value(true);
+          return 5;
+        }));
+
+    RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr);
+
+    recovery.WriteMetadata(/*current_term=*/5, /*voted_for=*/2,
+                           /*snapshot_last_index=*/5, /*snapshot_last_term=*/5);
+
+    for (int i = 1; i <= 8; ++i) {
+      Entry e;
+      e.set_term(i);
+      ClientTestRequest req;
+      req.set_value("Transaction " + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+
+    insert_done.set_value(true);
+    ckpt_fired_future.get();
+
+    for (int i = 9; i <= 10; ++i) {
+      Entry e;
+      e.set_term(i);
+      ClientTestRequest req;
+      req.set_value("Transaction " + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+  }
+
+  // Phase 2: a fresh RaftRecovery and Raft replay what phase 1 wrote.
+  MockSignatureVerifier verifier;
+  ResDBConfig config = MakeConfig();
+  MockLeaderElectionManager lem(config);
+  MockReplicaCommunicator comm;
+  MockCheckPoint ckpt;
+
+  RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+  Raft raft(/*id=*/1, /*f=*/1, /*total=*/4, &verifier, &lem, &comm, &recovery);
+
+  raft.SetCommitFunc([&](const google::protobuf::Message& msg) {
+    return mock_commit.Commit(msg);
+  });
+  raft.SetSingleCallFunc(
+      [&](int type, const google::protobuf::Message& msg, int node_id) {
+        return mock_call.Call(type, msg, node_id);
+      });
+
+  RecoverFromLogs(recovery, raft);
+
+  EXPECT_EQ(raft.GetCurrentTerm(), 5u);
+  EXPECT_EQ(raft.GetVotedFor(), 2);
+  EXPECT_EQ(raft.GetSnapshotLastIndex(), 5u);
+
+  // Phase 3: a term-11 leader appends entries 11-13 after index 10 and
+  // reports leaderCommit 7.
+  auto aefields = CreateAeFields(
+      /*term=*/11,
+      /*leaderId=*/2,
+      /*prevLogIndex=*/10,
+      /*prevLogTerm=*/10,
+      /*entries=*/
+      CreateLogEntries({
+          {11, "Transaction 11"},
+          {11, "Transaction 12"},
+          {11, "Transaction 13"},
+      }),
+      /*leaderCommit=*/7,
+      /*followerId=*/1);
+
+  auto aemessage = CreateAeMessage(aefields);
+
+  bool success = raft.ReceiveAppendEntries(
+      std::make_unique(std::move(aemessage)));
+  EXPECT_TRUE(success);
+
+  // --- Assertions ---
+  EXPECT_EQ(raft.GetCurrentTerm(), 11u);
+  EXPECT_EQ(raft.GetVotedFor(), -1);
+  EXPECT_EQ(raft.GetSnapshotLastIndex(), 5u);
+  EXPECT_EQ(raft.GetLastLogIndex(), 13u);
+
+  // Log: position 0 is the sentinel (term/index=5). Positions 1–5 are the
+  // replayed entries (absolute indices 6–10) and positions 6–8 are the
+  // entries appended above (absolute indices 11–13).
+  ASSERT_EQ(raft.GetLogSize(), 9u);
+  // Bind the log once so element references stay valid whether GetLog()
+  // returns a reference or a copy.
+  const auto& log = raft.GetLog();
+  EXPECT_EQ(log[0].entry.term(), 5);
+  EXPECT_EQ(raft.GetLogTermAtIndex(5), 5);
+
+  // Check every non-sentinel position, including the last one (position 8).
+  for (int i = 1; i <= 8; ++i) {
+    const auto& le = log[i];
+    if (i <= 5) {
+      EXPECT_EQ(le.entry.term(), i + 5);
+    } else {
+      EXPECT_EQ(le.entry.term(), 11);
+    }
+    EXPECT_EQ(raft.GetLogTermAtIndex(i + 5), le.entry.term());
+    ClientTestRequest req;
+    req.ParseFromString(le.entry.command());
+    EXPECT_EQ(req.value(), "Transaction " + std::to_string(i + 5));
+    // The absolute-index accessor must return the same payload.
+    ClientTestRequest req2;
+    auto log_entry = raft.GetLogEntryAtIndex(i + 5);
+    req2.ParseFromString(log_entry.entry.command());
+    EXPECT_EQ(req.value(), req2.value());
+  }
+}
+
+// Test 3: Demotion (higher-term AppendEntries) triggers WriteMetadata, and the
+// updated metadata is visible after recovery.
+TEST_F(RaftRecoveryIntegrationTest, DemotionTriggersWriteMetadata) {
+  // Phase 1: a term-6 leader receives a term-7 AppendEntries and must end up
+  // in term 7 with votedFor reset to -1.
+  // NOTE(review): `std::make_unique(...)` has no template argument in this
+  // copy of the patch (likely lost in extraction); restore it from the
+  // upstream file.
+  {
+    MockSignatureVerifier verifier;
+    ResDBConfig config = MakeConfig();
+    MockLeaderElectionManager lem(config);
+    MockReplicaCommunicator comm;
+
+    RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+    Raft raft(/*id=*/1, /*f=*/1, /*total=*/4, &verifier, &lem, &comm,
+              &recovery);
+
+    recovery.WriteMetadata(/*current_term=*/3, /*voted_for=*/1,
+                           /*snapshot_last_index=*/0,
+                           /*snapshot_last_term=*/0);
+
+    // Add a couple of entries so the log is non-trivial.
+    for (int i = 1; i <= 2; ++i) {
+      Entry e;
+      e.set_term(3);
+      ClientTestRequest req;
+      req.set_value("cmd-" + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+
+    raft.SetStateForTest({
+        .currentTerm = 6,
+        .role = Role::LEADER,
+        .log = CreateLogEntries({}, true),
+    });
+
+    raft.SetSingleCallFunc(
+        [&](int type, const google::protobuf::Message& msg, int node_id) {
+          return mock_call.Call(type, msg, node_id);
+        });
+
+    // Receive an AppendEntries from node 2 at a higher term.
+    auto aefields = CreateAeFields(
+        /*term=*/7,
+        /*leaderId=*/2,
+        /*prevLogIndex=*/0,
+        /*prevLogTerm=*/0,
+        /*entries=*/{},
+        /*leaderCommit=*/0,
+        /*followerId=*/1);
+    auto aemessage = CreateAeMessage(aefields);
+
+    raft.PrintDebugState();
+    bool success = raft.ReceiveAppendEntries(
+        std::make_unique(std::move(aemessage)));
+    EXPECT_TRUE(success);
+
+    EXPECT_EQ(raft.GetCurrentTerm(), 7u);
+    EXPECT_EQ(raft.GetVotedFor(), -1);
+  }
+
+  // Phase 2: a fresh RaftRecovery and Raft must recover the post-demotion
+  // metadata (term 7, votedFor -1) rather than the seeded term 3.
+  {
+    MockSignatureVerifier verifier;
+    ResDBConfig config = MakeConfig();
+    MockLeaderElectionManager lem(config);
+    MockReplicaCommunicator comm;
+
+    RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+    Raft raft(/*id=*/1, /*f=*/1, /*total=*/4, &verifier, &lem, &comm,
+              &recovery);
+
+    RecoverFromLogs(recovery, raft);
+
+    EXPECT_EQ(raft.GetCurrentTerm(), 7u);
+    EXPECT_EQ(raft.GetVotedFor(), -1);
+
+    // The two entries written before the demotion should still be present.
+    ASSERT_EQ(raft.GetLogSize(), 3u);
+    for (int i = 1; i <= 2; ++i) {
+      // NOTE(review): if GetLog() returns by value, `le` refers to an element
+      // of a destroyed temporary -- confirm it returns a reference.
+      const auto& le = raft.GetLog()[i];
+      EXPECT_EQ(le.entry.term(), 3);
+      ClientTestRequest req;
+      req.ParseFromString(le.entry.command());
+      EXPECT_EQ(req.value(), "cmd-" + std::to_string(i));
+    }
+  }
+}
+
+// Test 4: A truncation that occurs after a checkpoint is replayed correctly.
+TEST_F(RaftRecoveryIntegrationTest, TruncationPersistsAfterCheckpoint) {
+  // Timeline:
+  //   - Write entries 1–5 (all term 3).
+  //   - Checkpoint fires at stable index 2 → WAL is compacted up to index 2.
+  //   - Truncate from index 4 onward.
+  //   - Write new entries at index 4–5 with term 6 and different commands.
+  // NOTE(review): `std::promise` has no template argument in this copy of the
+  // patch (likely lost in extraction); restore it from the upstream file.
+  {
+    std::promise insert_done, ckpt_fired;
+    auto insert_done_future = insert_done.get_future();
+    auto ckpt_fired_future = ckpt_fired.get_future();
+
+    int call_count = 0;
+    EXPECT_CALL(checkpoint_, GetStableCheckpoint())
+        .WillRepeatedly(Invoke([&]() -> uint64_t {
+          ++call_count;
+          if (call_count == 1)
+            insert_done_future.get();  // block until initial entries are in
+          else if (call_count == 2)
+            ckpt_fired.set_value(true);  // signal that the checkpoint fired
+          return 2;                      // checkpoint covers indices 1–2
+        }));
+
+    RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr);
+
+    // Seed metadata.
+    recovery.WriteMetadata(/*current_term=*/3, /*voted_for=*/1,
+                           /*snapshot_last_index=*/0,
+                           /*snapshot_last_term=*/0);
+
+    // Write entries 1–5 at term 3.
+    for (int i = 1; i <= 5; ++i) {
+      Entry e;
+      e.set_term(3);
+      ClientTestRequest req;
+      req.set_value("original-" + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+
+    // Unblock the checkpoint poll and wait for it to fire.
+    insert_done.set_value(true);
+    ckpt_fired_future.get();
+
+    // Record the snapshot position (index 2, term 3) in the metadata.
+    recovery.WriteMetadata(/*current_term=*/3, /*voted_for=*/1,
+                           /*snapshot_last_index=*/2,
+                           /*snapshot_last_term=*/3);
+
+    // Truncate from index 4 onward (entries 4 and 5 are discarded).
+    // NOTE: Assumes RaftRecovery::TruncateLog(from_index) writes a
+    // kTruncation WAL record. Adjust the call if the API differs.
+    TruncationRecord truncation;
+    truncation.set_truncate_from_index(4);
+    truncation.set_truncate_from_term(3);
+    recovery.TruncateLog(truncation);
+
+    // Rewrite indices 4–5 under term 6 with different commands.
+    for (int i = 4; i <= 5; ++i) {
+      Entry e;
+      e.set_term(6);
+      ClientTestRequest req;
+      req.set_value("rewritten-" + std::to_string(i));
+      req.SerializeToString(e.mutable_command());
+      recovery.AddLogEntry(&e, i);
+    }
+  }
+
+  {
+    MockSignatureVerifier verifier;
+    ResDBConfig config = MakeConfig();
+    MockLeaderElectionManager lem(config);
+    MockReplicaCommunicator comm;
+
+    RaftRecovery recovery(config_, nullptr, nullptr, nullptr);
+
+    Raft raft(/*id=*/1, /*f=*/1, /*total=*/4, &verifier, &lem, &comm,
+              &recovery);
+
+    // Recover and verify: entry 3 is untouched, 4–5 carry the new data.
+    RecoverFromLogs(recovery, raft);
+
+    EXPECT_EQ(raft.GetSnapshotLastIndex(), 2u);
+
+    // Without compaction the log would hold the sentinel plus entries 1–5
+    // (6 total). The WAL after compaction starts from the checkpoint (index 2
+    // sentinel), then replays entries 3, 4 (rewritten), 5 (rewritten), so
+    // 4 entries are expected.
+    EXPECT_EQ(raft.GetLogSize(), 4u);
+
+    // Entry at absolute index 3 should be original.
+    {
+      auto le = raft.GetLogEntryAtIndex(3);
+      EXPECT_EQ(le.entry.term(), 3);
+      ClientTestRequest req;
+      req.ParseFromString(le.entry.command());
+      EXPECT_EQ(req.value(), "original-3");
+    }
+
+    // Entries at absolute indices 4–5 must reflect the post-truncation rewrite.
+    for (int i = 4; i <= 5; ++i) {
+      auto le = raft.GetLogEntryAtIndex(i);
+      EXPECT_EQ(le.entry.term(), 6);
+      ClientTestRequest req;
+      req.ParseFromString(le.entry.command());
+      EXPECT_EQ(req.value(), "rewritten-" + std::to_string(i));
+    }
+
+    EXPECT_EQ(raft.GetLogTermAtIndex(4), 6);
+    EXPECT_EQ(raft.GetLogTermAtIndex(5), 6);
+  }
+}
+
+}  // namespace raft
+}  // namespace resdb
\ No newline at end of file
diff --git a/platform/consensus/ordering/raft/algorithm/raft_request_vote_response_test.cpp b/platform/consensus/ordering/raft/algorithm/raft_request_vote_response_test.cpp
new file mode 100644
index 0000000000..08cd6c456c
--- /dev/null
+++ b/platform/consensus/ordering/raft/algorithm/raft_request_vote_response_test.cpp
@@ -0,0 +1,204 @@
+#include "platform/consensus/ordering/raft/algorithm/raft_tests.h"
+
+namespace resdb {
+namespace raft {
+
+// Test 1: A candidate gets elected.
+TEST_F(RaftTest, CandidateGetsElected) {
+  // On winning, three messages are expected, to nodes 2, 3 and 4 in that
+  // order. Each carries no entries, prevLogIndex 2 / prevLogTerm 1 (the last
+  // entry of the log set below), leaderId 1 and leader commit index 1.
+  // NOTE(review): `dynamic_cast(msg)`, `std::vector{...}` and
+  // `std::make_unique(rvr)` have no template arguments in this copy of the
+  // patch (likely lost in extraction); restore them from the upstream file.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& AppendEntriesMessage =
+                dynamic_cast(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(AppendEntriesMessage.entries_size(), 0);
+            EXPECT_EQ(AppendEntriesMessage.prevlogterm(), 1);
+            EXPECT_EQ(AppendEntriesMessage.prevlogindex(), 2);
+            EXPECT_EQ(AppendEntriesMessage.leaderid(), 1);
+            EXPECT_EQ(AppendEntriesMessage.leadercommitindex(), 1);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& AppendEntriesMessage =
+                dynamic_cast(msg);
+            EXPECT_EQ(node_id, 3);
+            EXPECT_EQ(AppendEntriesMessage.entries_size(), 0);
+            EXPECT_EQ(AppendEntriesMessage.prevlogterm(), 1);
+            EXPECT_EQ(AppendEntriesMessage.prevlogindex(), 2);
+            EXPECT_EQ(AppendEntriesMessage.leaderid(), 1);
+            EXPECT_EQ(AppendEntriesMessage.leadercommitindex(), 1);
+            return 0;
+          }))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& AppendEntriesMessage =
+                dynamic_cast(msg);
+            EXPECT_EQ(node_id, 4);
+            EXPECT_EQ(AppendEntriesMessage.entries_size(), 0);
+            EXPECT_EQ(AppendEntriesMessage.prevlogterm(), 1);
+            EXPECT_EQ(AppendEntriesMessage.prevlogindex(), 2);
+            EXPECT_EQ(AppendEntriesMessage.leaderid(), 1);
+            EXPECT_EQ(AppendEntriesMessage.leadercommitindex(), 1);
+            return 0;
+          }));
+
+  // Candidate in term 2 that already holds votes from nodes 1 and 3.
+  raft_->SetStateForTest({.currentTerm = 2,
+                          .commitIndex = 1,
+                          .lastCommitted = 1,
+                          .role = Role::CANDIDATE,
+                          .log = CreateLogEntries(
+                              {
+                                  {0, "Term 0 Transaction 1"},
+                                  {1, "Term 1 Transaction 1"},
+                              },
+                              true),
+                          .votes = std::vector{1, 3}});
+
+  // A granted vote from node 2 in the same term.
+  RequestVoteResponse rvr;
+  rvr.set_term(2);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(true);
+  raft_->ReceiveRequestVoteResponse(std::make_unique(rvr));
+
+  EXPECT_EQ(raft_->GetCurrentTerm(), 2);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::LEADER);
+  EXPECT_EQ(raft_->GetLastLogIndexFromLog(), 2);
+  // Every nextIndex slot is expected to be last log index + 1; only slot 1 of
+  // matchIndex (this node's id) is expected to equal the last log index.
+  EXPECT_THAT(raft_->GetNextIndex(), ::testing::ElementsAre(3, 3, 3, 3, 3));
+  EXPECT_THAT(raft_->GetMatchIndex(), ::testing::ElementsAre(0, 2, 0, 0, 0));
+}
+
+// Test 2: A candidate receives a RequestVoteResponse from an older term and
+// ignores it.
+TEST_F(RaftTest, CandidateIgnoresResponseFromOldTerm) {
+  // Nothing may be sent and the election manager must not be notified.
+  // NOTE(review): `std::make_unique(rvr)` has no template argument in this
+  // copy of the patch (likely lost in extraction); restore it from the
+  // upstream file.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  // The response is from term 1; the candidate is in term 2.
+  RequestVoteResponse rvr;
+  rvr.set_term(1);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(true);
+  raft_->ReceiveRequestVoteResponse(std::make_unique(rvr));
+
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+// Test 3: A candidate receives a RequestVoteResponse from a newer term and
+// demotes.
+TEST_F(RaftTest, CandidateDemotesAfterRequestVoteResponseFromNewerTerm) {
+  // Exactly one role change is expected and nothing may be sent.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  // The response carries term 3, which is newer than the candidate's term 2.
+  RequestVoteResponse rvr;
+  rvr.set_term(3);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(false);
+  raft_->ReceiveRequestVoteResponse(
+      std::make_unique<RequestVoteResponse>(rvr));
+
+  EXPECT_EQ(raft_->GetVotedFor(), -1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 3);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 4: A follower ignores a RequestVoteResponse.
+TEST_F(RaftTest, FollowerIgnoresRequestVoteResponse) {
+  // No role change, no outgoing message and no heartbeat are expected.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  RequestVoteResponse rvr;
+  rvr.set_term(2);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(true);
+  raft_->ReceiveRequestVoteResponse(
+      std::make_unique<RequestVoteResponse>(rvr));
+
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 5: A candidate ignores a no vote in a RequestVoteResponse.
+TEST_F(RaftTest, CandidateIgnoresNoVote) {
+  // No role change, no outgoing message and no heartbeat are expected.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  RequestVoteResponse rvr;
+  rvr.set_term(2);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(false);
+  raft_->ReceiveRequestVoteResponse(
+      std::make_unique<RequestVoteResponse>(rvr));
+
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+// Test 6: A candidate ignores a duplicate vote.
+TEST_F(RaftTest, CandidateIgnoresDuplicateVote) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _)).Times(0);
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  // Voter 2 is already present in `votes`, so the response below repeats it.
+  raft_->SetStateForTest({.currentTerm = 2,
+                          .commitIndex = 1,
+                          .lastCommitted = 1,
+                          .role = Role::CANDIDATE,
+                          .log = CreateLogEntries(
+                              {
+                                  {0, "Term 0 Transaction 1"},
+                                  {1, "Term 1 Transaction 1"},
+                              },
+                              true),
+                          .votes = std::vector{1, 2}});
+
+  RequestVoteResponse rvr;
+  rvr.set_term(2);
+  rvr.set_voterid(2);
+  rvr.set_votegranted(true);
+  raft_->ReceiveRequestVoteResponse(
+      std::make_unique<RequestVoteResponse>(rvr));
+
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/algorithm/raft_request_vote_test.cpp b/platform/consensus/ordering/raft/algorithm/raft_request_vote_test.cpp
new file mode 100644
index 0000000000..156a8ab138
--- /dev/null
+++ b/platform/consensus/ordering/raft/algorithm/raft_request_vote_test.cpp
@@ -0,0 +1,326 @@
+#include "platform/consensus/ordering/raft/algorithm/raft_tests.h"
+
+namespace resdb {
+namespace raft {
+
+// Test 1: A follower times out, transitions to candidate, and starts an
+// election.
+TEST_F(RaftTest, FollowerTransitionsToCandidateAndStartsElection) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  // The broadcast RequestVote must describe the seeded log: last index 1,
+  // last term 0.
+  EXPECT_CALL(mock_broadcast, Broadcast(_, _))
+      .WillOnce(
+          ::testing::Invoke([](int type, const google::protobuf::Message& msg) {
+            const auto& requestVote = dynamic_cast<const RequestVote&>(msg);
+            EXPECT_EQ(requestVote.term(), 1);
+            EXPECT_EQ(requestVote.candidateid(), 1);
+            EXPECT_EQ(requestVote.lastlogindex(), 1);
+            EXPECT_EQ(requestVote.lastlogterm(), 0);
+            return 0;
+          }));
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->StartElection();
+  EXPECT_EQ(raft_->GetVotedFor(), 1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+// Test 2: A leader receives a RequestVote from a candidate in a newer term and
+// demotes.
+TEST_F(RaftTest, LeaderReceivesRequestVoteFromNewTermAndDemotes) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 1);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_TRUE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(1);
+
+  RequestVote rv;
+  rv.set_term(1);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(1);
+  rv.set_lastlogterm(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::LEADER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), 2);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 3: A leader receives a RequestVote from a candidate in its own term
+// whose lastLogIndex is behind the leader's log, and does not vote.
+TEST_F(RaftTest, LeaderReceivesRequestVoteFromOldTerm) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 1);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_FALSE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  RequestVote rv;
+  rv.set_term(1);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(0);
+  rv.set_lastlogterm(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .role = Role::LEADER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), -1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::LEADER);
+}
+
+// Test 4: A leader receives a RequestVote from a candidate whose lastLogTerm is
+// less recent.
+TEST_F(RaftTest, LeaderReceivesRequestVoteFromOlderLastLogTerm) {
+  // NOTE(review): the request and the seeded state below are identical to
+  // LeaderReceivesRequestVoteFromOldTerm. Both the request's lastLogTerm and
+  // every seeded log entry use term 0, so an older lastLogTerm is not actually
+  // exercised. Presumably the seeded log should end in a higher term; confirm.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 1);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_FALSE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  RequestVote rv;
+  rv.set_term(1);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(0);
+  rv.set_lastlogterm(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .role = Role::LEADER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), -1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::LEADER);
+}
+
+// Test 5: A leader receives a RequestVote from a candidate whose lastLogTerm is
+// the same, but whose lastLogIndex is further behind.
+TEST_F(RaftTest, LeaderReceivesRequestVoteFromFurtherBehindLog) {
+  // The request's term (2) is newer than the leader's (1), so one role change
+  // is expected even though the vote is refused.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 2);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_FALSE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  RequestVote rv;
+  rv.set_term(2);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(0);
+  rv.set_lastlogterm(0);
+
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .role = Role::LEADER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), -1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 2);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 6: A follower receives a RequestVote from a candidate who it would vote
+// for, if it had not already voted for someone else.
+TEST_F(RaftTest, FollowerRejectsRequestVoteBecauseAlreadyVoted) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 2);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_FALSE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  RequestVote rv;
+  rv.set_term(2);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(2);
+  rv.set_lastlogterm(1);
+
+  // The follower has already voted for node 3 in term 2.
+  raft_->SetStateForTest({
+      .currentTerm = 2,
+      .votedFor = 3,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), 3);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 2);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::FOLLOWER);
+}
+
+// Test 7: A follower times out and starts an election. Then, as a candidate
+// times out and starts another election.
+TEST_F(RaftTest, CandidateTimesOutAndStartsAnotherElection) {
+  // Only the first election changes the role, but each election broadcasts a
+  // RequestVote with the next term.
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(1);
+  EXPECT_CALL(mock_broadcast, Broadcast(_, _))
+      .WillOnce(
+          ::testing::Invoke([](int type, const google::protobuf::Message& msg) {
+            const auto& requestVote = dynamic_cast<const RequestVote&>(msg);
+            EXPECT_EQ(requestVote.term(), 1);
+            EXPECT_EQ(requestVote.candidateid(), 1);
+            EXPECT_EQ(requestVote.lastlogindex(), 1);
+            EXPECT_EQ(requestVote.lastlogterm(), 0);
+            return 0;
+          }))
+      .WillOnce(
+          ::testing::Invoke([](int type, const google::protobuf::Message& msg) {
+            const auto& requestVote = dynamic_cast<const RequestVote&>(msg);
+            EXPECT_EQ(requestVote.term(), 2);
+            EXPECT_EQ(requestVote.candidateid(), 1);
+            EXPECT_EQ(requestVote.lastlogindex(), 1);
+            EXPECT_EQ(requestVote.lastlogterm(), 0);
+            return 0;
+          }));
+
+  raft_->SetStateForTest({
+      .currentTerm = 0,
+      .role = Role::FOLLOWER,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->StartElection();
+  EXPECT_EQ(raft_->GetVotedFor(), 1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+
+  // Start another election after a timeout
+  raft_->StartElection();
+  EXPECT_EQ(raft_->GetVotedFor(), 1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 2);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+// Test 8: A candidate receives a RequestVote from another candidate in the same
+// term and does not demote.
+TEST_F(RaftTest, CandidateReceivesRequestVoteFromSameTermAndDoesNotDemote) {
+  EXPECT_CALL(*leader_election_manager_, OnRoleChange()).Times(0);
+  EXPECT_CALL(mock_call, Call(_, _, _))
+      .WillOnce(::testing::Invoke(
+          [](int type, const google::protobuf::Message& msg, int node_id) {
+            const auto& requestVoteResponse =
+                dynamic_cast<const RequestVoteResponse&>(msg);
+            EXPECT_EQ(node_id, 2);
+            EXPECT_EQ(requestVoteResponse.term(), 1);
+            EXPECT_EQ(requestVoteResponse.voterid(), 1);
+            EXPECT_FALSE(requestVoteResponse.votegranted());
+            return 0;
+          }));
+  EXPECT_CALL(*leader_election_manager_, OnHeartBeat()).Times(0);
+
+  RequestVote rv;
+  rv.set_term(1);
+  rv.set_candidateid(2);
+  rv.set_lastlogindex(1);
+  rv.set_lastlogterm(0);
+
+  // The candidate has already voted for itself (node 1) in term 1.
+  raft_->SetStateForTest({
+      .currentTerm = 1,
+      .votedFor = 1,
+      .role = Role::CANDIDATE,
+      .log = CreateLogEntries(
+          {
+              {0, "Term 0 Transaction 1"},
+          },
+          true),
+  });
+
+  raft_->ReceiveRequestVote(std::make_unique<RequestVote>(rv));
+
+  EXPECT_EQ(raft_->GetVotedFor(), 1);
+  EXPECT_EQ(raft_->GetCurrentTerm(), 1);
+  EXPECT_EQ(raft_->GetRoleSnapshot(), Role::CANDIDATE);
+}
+
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/algorithm/raft_test_util.h b/platform/consensus/ordering/raft/algorithm/raft_test_util.h
new file mode 100644
index 0000000000..b468edc230
--- /dev/null
+++ b/platform/consensus/ordering/raft/algorithm/raft_test_util.h
@@ -0,0 +1,118 @@
+#pragma once
+
+#include <gmock/gmock.h>
+
+#include "common/crypto/mock_signature_verifier.h"
+#include "platform/config/resdb_config_utils.h"
+#include "platform/consensus/ordering/raft/algorithm/mock_leader_election_manager.h"
+#include "platform/consensus/ordering/raft/algorithm/raft.h"
+#include "platform/networkstrate/mock_replica_communicator.h"
+#include "platform/proto/client_test.pb.h"
+
+namespace resdb {
+namespace raft {
+namespace test_utils {
+
+// Builds a four-replica config (ids 1-4 on 127.0.0.1:1234-1237) in which the
+// local node is replica 1.
+inline ResDBConfig GenerateConfig() {
+  ResConfigData data;
+  data.set_duplicate_check_frequency_useconds(100000);
+  data.set_enable_viewchange(true);
+  return ResDBConfig({GenerateReplicaInfo(1, "127.0.0.1", 1234),
+                      GenerateReplicaInfo(2, "127.0.0.1", 1235),
+                      GenerateReplicaInfo(3, "127.0.0.1", 1236),
+                      GenerateReplicaInfo(4, "127.0.0.1", 1237)},
+                     GenerateReplicaInfo(1, "127.0.0.1", 1234), data);
+}
+
+// gMock stand-ins for the send, broadcast and commit callbacks, so tests can
+// set expectations on the messages passed to them.
+class MockSendMessageFunction {
+ public:
+  MOCK_METHOD(int, Call, (int, const google::protobuf::Message&, int));
+};
+class MockBroadcastFunction {
+ public:
+  MOCK_METHOD(int, Broadcast, (int, const google::protobuf::Message&));
+};
+class MockCommitFunction {
+ public:
+  MOCK_METHOD(int, Commit, (const google::protobuf::Message&));
+};
+
+// Assembles an AeFields value from its parts. Only the term and command of
+// each entry in `entries` are copied.
+inline AeFields CreateAeFields(uint64_t term, int leaderId,
+                               uint64_t prevLogIndex, uint64_t prevLogTerm,
+                               const std::vector<LogEntry>& entries,
+                               uint64_t leaderCommit, int followerId) {
+  AeFields fields{};
+  fields.term = term;
+  fields.leaderId = leaderId;
+  fields.leaderCommit = leaderCommit;
+  fields.prevLogIndex = prevLogIndex;
+  fields.prevLogTerm = prevLogTerm;
+  fields.followerId = followerId;
+
+  for (const auto& logEntry : entries) {
+    LogEntry log_entry;
+    log_entry.entry.set_term(logEntry.entry.term());
+    log_entry.entry.set_command(logEntry.entry.command());
+    fields.entries.push_back(std::move(log_entry));
+  }
+
+  return fields;
+}
+
+// Helper to create a single log entry.
+inline LogEntry CreateLogEntry(uint64_t term, const std::string& command_data) {
+  LogEntry log_entry;
+  log_entry.entry.set_term(term);
+  log_entry.entry.set_command(command_data);
+  return log_entry;
+}
+
+// Helper to create a vector of log entries for testing.
+// Each {term, value} pair becomes one entry whose command is a serialized
+// ClientTestRequest carrying `value`. When `usedForLogPatch` is true, a term-0
+// "COMMON_PREFIX" entry is placed in front of the generated entries.
+inline std::vector<LogEntry> CreateLogEntries(
+    const std::vector<std::pair<uint64_t, std::string>>& term_and_cmds,
+    bool usedForLogPatch = false) {
+  std::vector<LogEntry> entries;
+  entries.reserve(term_and_cmds.size() + (usedForLogPatch ? 1 : 0));
+
+  if (usedForLogPatch) {
+    LogEntry first_entry;
+    first_entry.entry.set_term(0);
+    first_entry.entry.set_command("COMMON_PREFIX");
+    entries.push_back(first_entry);
+  }
+
+  for (const auto& [term, cmd] : term_and_cmds) {
+    LogEntry log_entry;
+    log_entry.entry.set_term(term);
+
+    ClientTestRequest req;
+    req.set_value(cmd);
+
+    std::string serialized;
+    req.SerializeToString(&serialized);
+    log_entry.entry.set_command(serialized);
+
+    entries.push_back(log_entry);
+  }
+
+  return entries;
+}
+
+// Converts AeFields into an AppendEntries message. `fields.followerId` is not
+// copied into the message.
+inline AppendEntries CreateAeMessage(const AeFields& fields) {
+  AppendEntries ae;
+  ae.set_term(fields.term);
+  ae.set_leaderid(fields.leaderId);
+  ae.set_prevlogindex(fields.prevLogIndex);
+  ae.set_prevlogterm(fields.prevLogTerm);
+  ae.set_leadercommitindex(fields.leaderCommit);
+  for (const auto& log_entry : fields.entries) {
+    auto* newEntry = ae.add_entries();
+    newEntry->set_term(log_entry.entry.term());
+    newEntry->set_command(log_entry.entry.command());
+  }
+
+  return ae;
+}
+
+}  // namespace test_utils
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/algorithm/raft_tests.h b/platform/consensus/ordering/raft/algorithm/raft_tests.h
new file mode 100644
index 0000000000..5ead2ec963
--- /dev/null
+++ b/platform/consensus/ordering/raft/algorithm/raft_tests.h
@@ -0,0 +1,60 @@
+#pragma once
+
+#include "platform/consensus/ordering/raft/algorithm/raft_test_util.h"
+#include "platform/consensus/recovery/mock_raft_recovery.h"
+
+namespace resdb {
+namespace raft {
+using resdb::raft::test_utils::CreateAeFields;
+using resdb::raft::test_utils::CreateAeMessage;
+using resdb::raft::test_utils::CreateLogEntries;
+using resdb::raft::test_utils::GenerateConfig;
+using resdb::raft::test_utils::MockBroadcastFunction;
+using resdb::raft::test_utils::MockCommitFunction;
+using resdb::raft::test_utils::MockSendMessageFunction;
+using ::testing::_;
+using ::testing::AnyNumber;
+using ::testing::Invoke;
+using ::testing::Matcher;
+
+// Fixture that builds a Raft instance (id 1, f = 1, 4 replicas) on top of
+// mocked collaborators and routes its send, broadcast and commit callbacks to
+// gMock objects so tests can set expectations on them.
+class RaftTest : public ::testing::Test {
+
+ protected:
+  void SetUp() override {
+    verifier_ = std::make_unique<MockSignatureVerifier>();
+    // NOTE(review): config_ is local to SetUp() despite the member-style name.
+    // This is only safe if the objects built from it keep their own copy;
+    // confirm.
+    ResDBConfig config_ = GenerateConfig();
+    leader_election_manager_ =
+        std::make_unique<MockLeaderElectionManager>(config_);
+    replica_communicator_ = std::make_unique<MockReplicaCommunicator>();
+    recovery_ = std::make_unique<MockRaftRecovery>(config_);
+    raft_ = std::make_unique<Raft>(
+        /*id=*/1,
+        /*f=*/1,
+        /*total=*/4, verifier_.get(), leader_election_manager_.get(),
+        replica_communicator_.get(), recovery_.get());
+
+    raft_->SetSingleCallFunc(
+        [&](int type, const google::protobuf::Message& msg, int node_id) {
+          return mock_call.Call(type, msg, node_id);
+        });
+
+    raft_->SetBroadcastCallFunc(
+        [&](int type, const google::protobuf::Message& msg) {
+          return mock_broadcast.Broadcast(type, msg);
+        });
+
+    raft_->SetCommitFunc([&](const google::protobuf::Message& msg) {
+      return mock_commit.Commit(msg);
+    });
+  }
+
+  std::unique_ptr<MockSignatureVerifier> verifier_;
+  std::unique_ptr<MockLeaderElectionManager> leader_election_manager_;
+  std::unique_ptr<MockReplicaCommunicator> replica_communicator_;
+  std::unique_ptr<MockRaftRecovery> recovery_;
+  std::unique_ptr<Raft> raft_;
+  MockSendMessageFunction mock_call;
+  MockBroadcastFunction mock_broadcast;
+  MockCommitFunction mock_commit;
+};
+
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/framework/BUILD b/platform/consensus/ordering/raft/framework/BUILD
new file mode 100644
index 0000000000..6c70d57835
--- /dev/null
+++ b/platform/consensus/ordering/raft/framework/BUILD
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+package(default_visibility = ["//platform/consensus/ordering/raft:__subpackages__"])
+
+# Header-only checkpoint holder used by the Raft consensus target below.
+cc_library(
+    name = "raft_checkpoint_manager",
+    hdrs = ["raft_checkpoint_manager.h"],
+    visibility = ["//visibility:public"],
+    deps = ["//platform/consensus/checkpoint:checkpoint"],
+)
+
+# Raft consensus entry point (consensus.cpp / consensus.h).
+cc_library(
+    name = "consensus",
+    srcs = ["consensus.cpp"],
+    hdrs = ["consensus.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":raft_checkpoint_manager",
+        "//common/utils",
+        "//platform/consensus/ordering/common/framework:consensus",
+        "//platform/consensus/ordering/raft/algorithm:raft",
+        "//platform/consensus/recovery:raft_recovery",
+    ],
+)
diff --git a/platform/consensus/ordering/raft/framework/consensus.cpp b/platform/consensus/ordering/raft/framework/consensus.cpp
new file mode 100644
index 0000000000..c6c01117cb
--- /dev/null
+++ b/platform/consensus/ordering/raft/framework/consensus.cpp
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "platform/consensus/ordering/raft/framework/consensus.h"
+
+#include <glog/logging.h>
+#include <unistd.h>
+
+#include <algorithm>
+#include <cassert>
+#include <memory>
+#include <utility>
+
+#include "common/utils/utils.h"
+#include "platform/consensus/ordering/raft/proto/proposal.pb.h"
+#include "platform/proto/resdb.pb.h"
+
+namespace resdb {
+namespace raft {
+
+namespace {
+
+// Parses the payload of `request` into a new message of type T. Logs and
+// returns nullptr when the payload does not parse, so a malformed message is
+// rejected instead of aborting the process.
+template <typename T>
+std::unique_ptr<T> ParsePayload(const Request& request) {
+  auto message = std::make_unique<T>();
+  if (!message->ParseFromString(request.data())) {
+    LOG(ERROR) << "parse proposal fail";
+    return nullptr;
+  }
+  return message;
+}
+
+}  // namespace
+
+// Builds the Raft collaborators and, unless this node's certificate type is
+// CLIENT, creates the Raft instance, starts leader election and replays the
+// recovery log.
+Consensus::Consensus(const ResDBConfig& config,
+                     std::unique_ptr<TransactionManager> executor)
+    : common::Consensus(config, std::move(executor)),
+      leader_election_manager_(
+          std::make_unique<LeaderElectionManager>(config_)),
+      system_info_(std::make_unique<SystemInfo>(config_)),
+      raft_checkpoint_manager_(std::make_unique<RaftCheckPoint>()),
+      recovery_(std::make_unique<RaftRecovery>(
+          config_, raft_checkpoint_manager_.get(),
+          transaction_executor_->GetStorage(),
+          [this](uint64_t seq) { OnCheckpointFinish(seq); })) {
+  const int total_replicas = config_.GetReplicaNum();
+  // NOTE(review): (n - 1) / 3 is the Byzantine fault bound. Confirm Raft
+  // expects this for `f` rather than the crash-fault bound (n - 1) / 2.
+  const int f = (total_replicas - 1) / 3;
+
+  Init();
+
+  if (config_.GetPublicKeyCertificateInfo()
+          .public_key()
+          .public_key_info()
+          .type() != CertificateKeyInfo::CLIENT) {
+    raft_ = std::make_unique<Raft>(
+        config_.GetSelfInfo().id(), f, total_replicas, GetSignatureVerifier(),
+        leader_election_manager_.get(), replica_communicator_, recovery_.get());
+
+    leader_election_manager_->SetRaft(raft_.get());
+    leader_election_manager_->MayStart();
+
+    RecoverFromLogs();
+
+    InitProtocol(raft_.get());
+  }
+}
+
+// Dispatches a Raft protocol message by its user_type(). Returns -1 when the
+// payload cannot be parsed and 0 otherwise, including for unknown types.
+int Consensus::ProcessCustomConsensus(std::unique_ptr<Request> request) {
+  if (request->user_type() == MessageType::AppendEntriesMsg) {
+    auto append_entries = ParsePayload<AppendEntries>(*request);
+    if (!append_entries) {
+      return -1;
+    }
+    raft_->ReceiveAppendEntries(std::move(append_entries));
+    return 0;
+  }
+  if (request->user_type() == MessageType::AppendEntriesResponseMsg) {
+    auto append_entries_response =
+        ParsePayload<AppendEntriesResponse>(*request);
+    if (!append_entries_response) {
+      return -1;
+    }
+    raft_->ReceiveAppendEntriesResponse(std::move(append_entries_response));
+    return 0;
+  }
+  if (request->user_type() == MessageType::RequestVoteMsg) {
+    auto request_vote = ParsePayload<RequestVote>(*request);
+    if (!request_vote) {
+      return -1;
+    }
+    raft_->ReceiveRequestVote(std::move(request_vote));
+    return 0;
+  }
+  if (request->user_type() == MessageType::RequestVoteResponseMsg) {
+    auto request_vote_response = ParsePayload<RequestVoteResponse>(*request);
+    if (!request_vote_response) {
+      return -1;
+    }
+    raft_->ReceiveRequestVoteResponse(std::move(request_vote_response));
+    return 0;
+  }
+  if (request->user_type() == MessageType::DirectToLeaderMsg) {
+    auto direct_to_leader = ParsePayload<DirectToLeader>(*request);
+    if (!direct_to_leader) {
+      return -1;
+    }
+    performance_manager_->SetPrimary(direct_to_leader->leaderid());
+    return 0;
+  }
+  LOG(ERROR) << "Unknown message type";
+  return 0;
+}
+
+// Replays persisted state into raft_. Every setter is called with `false` as
+// its last argument (presumably "do not persist again"; confirm against Raft).
+void Consensus::RecoverFromLogs() {
+  recovery_->ReadLogs(
+      [&](const RaftMetadata& metadata) {
+        LOG(INFO) << " read current term: " << metadata.current_term
+                  << " voted for: " << metadata.voted_for;
+        raft_->SetCurrentTerm(metadata.current_term, false);
+        raft_->SetVotedFor(metadata.voted_for, false);
+        raft_->SetSnapshotLastIndexAndTerm(metadata.snapshot_last_index,
+                                           metadata.snapshot_last_term, false);
+      },
+      [&](std::unique_ptr<WALRecord> record) {
+        switch (record->payload_case()) {
+          case WALRecord::kEntry: {
+            LogEntry logEntry;
+            logEntry.entry = record->entry();
+            raft_->AddToLog(logEntry, false);
+            break;
+          }
+          case WALRecord::kTruncation:
+            raft_->TruncateLog(record->truncation().truncate_from_index(),
+                               false);
+            break;
+          case WALRecord::PAYLOAD_NOT_SET:
+            // Logged as well as asserted, so release builds also report it.
+            LOG(ERROR) << "WALRecord does not contain Truncation or Entry";
+            assert(false && "WALRecord does not contain Truncation or Entry");
+            break;
+        }
+      },
+      [](int) {});
+}
+
+// Hands a new client transaction to Raft.
+int Consensus::ProcessNewTransaction(std::unique_ptr<Request> request) {
+  return raft_->ReceiveTransaction(std::move(request));
+}
+
+// Copies `msg` (which must be a Request) and passes it to the executor.
+// Returns -1 if `msg` is not a Request.
+int Consensus::CommitMsg(const google::protobuf::Message& msg) {
+  const auto* req = dynamic_cast<const Request*>(&msg);
+  if (!req) {
+    LOG(ERROR) << __FUNCTION__ << ": Failed to cast Message to Request";
+    return -1;
+  }
+  auto execReq = std::make_unique<Request>(*req);
+  transaction_executor_->Commit(std::move(execReq));
+  return 0;
+}
+
+// Tracks the highest executed sequence number and publishes a new stable
+// checkpoint once snapshot_interval_ sequence numbers have passed since the
+// previous one.
+int Consensus::ResponseMsg(const BatchUserResponse& batch_resp) {
+  // While we may receive these ResponseMsg's out of order, we do know the
+  // execution of transactions are guaranteed to be in order, so we know all
+  // transactions before batch_resp.seq() have been executed.
+  last_applied_ = std::max(batch_resp.seq(), last_applied_);
+
+  if (batch_resp.seq() >= snapshot_interval_ + last_snapshot_initiated_at_) {
+    LOG(INFO) << "Initiating checkpoint at seq: " << batch_resp.seq();
+    // Update the checkpoint in the manager
+    raft_checkpoint_manager_->SetStableCheckpoint(batch_resp.seq());
+    last_snapshot_initiated_at_ = batch_resp.seq();
+    LOG(INFO) << "Next Checkpoint will be after "
+              << (snapshot_interval_ + last_snapshot_initiated_at_);
+  }
+  return common::Consensus::ResponseMsg(batch_resp);
+}
+
+// Callback given to RaftRecovery; `seq` is the sequence number up to which
+// entries were checkpointed.
+void Consensus::OnCheckpointFinish(uint64_t seq) {
+  LOG(INFO) << "Checkpointed all entries up to " << seq;
+  // Log-prefix truncation is currently disabled:
+  // raft_->TruncatePrefix(seq);
+}
+
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/framework/consensus.h b/platform/consensus/ordering/raft/framework/consensus.h
new file mode 100644
index 0000000000..68ed55c44e
--- /dev/null
+++ b/platform/consensus/ordering/raft/framework/consensus.h
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <memory>
+
+#include "executor/common/transaction_manager.h"
+#include "platform/consensus/ordering/common/framework/consensus.h"
+#include "platform/consensus/ordering/raft/algorithm/leaderelection_manager.h"
+#include "platform/consensus/ordering/raft/algorithm/raft.h"
+#include "platform/consensus/ordering/raft/framework/raft_checkpoint_manager.h"
+#include "platform/consensus/recovery/raft_recovery.h"
+#include "platform/networkstrate/consensus_manager.h"
+
+namespace resdb {
+namespace raft {
+
+// Raft flavour of the common consensus framework: routes protocol messages
+// and client transactions to Raft, commits ordered requests to the executor,
+// and drives checkpointing and log recovery.
+class Consensus : public common::Consensus {
+ public:
+  Consensus(const ResDBConfig& config,
+            std::unique_ptr<TransactionManager> transaction_manager);
+  virtual ~Consensus() = default;
+
+ private:
+  int ProcessCustomConsensus(std::unique_ptr<Request> request) override;
+  int ProcessNewTransaction(std::unique_ptr<Request> request) override;
+  int CommitMsg(const google::protobuf::Message& msg) override;
+  // NOTE(review): no definition of this function appears in consensus.cpp;
+  // confirm whether the declaration is still needed.
+  int CommitMsgInternal(const AppendEntries& txn);
+  int ResponseMsg(const BatchUserResponse& batch_resp) override;
+  void RecoverFromLogs();
+  void OnCheckpointFinish(uint64_t seq);
+
+ protected:
+  std::unique_ptr<Raft> raft_;
+  std::unique_ptr<LeaderElectionManager> leader_election_manager_;
+  std::unique_ptr<SystemInfo> system_info_;
+  std::unique_ptr<RaftCheckPoint> raft_checkpoint_manager_;
+  std::unique_ptr<RaftRecovery> recovery_;
+  // Number of sequence numbers between two checkpoints.
+  uint32_t snapshot_interval_ = 1000;
+  // Highest sequence number seen in ResponseMsg().
+  uint64_t last_applied_ = 0;
+  // Sequence number at which the last checkpoint was initiated. 64-bit so the
+  // sequence number assigned to it is not truncated.
+  uint64_t last_snapshot_initiated_at_ = 0;
+};
+
+}  // namespace raft
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/framework/raft_checkpoint_manager.h b/platform/consensus/ordering/raft/framework/raft_checkpoint_manager.h
new file mode 100644
index 0000000000..c4774a27c7
--- /dev/null
+++ b/platform/consensus/ordering/raft/framework/raft_checkpoint_manager.h
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <cstdint>
+
+#include "platform/consensus/checkpoint/checkpoint.h"
+
+namespace resdb {
+
+// CheckPoint implementation that stores the latest stable sequence number in
+// an atomic, so it can be read and written without external locking.
+class RaftCheckPoint : public CheckPoint {
+ public:
+  RaftCheckPoint() = default;
+  virtual ~RaftCheckPoint() = default;
+
+  // Returns the most recently stored stable sequence number (0 initially).
+  virtual uint64_t GetStableCheckpoint() { return current_stable_seq_.load(); }
+  // Replaces the stored stable sequence number.
+  virtual void SetStableCheckpoint(uint64_t current_stable_seq) {
+    current_stable_seq_.store(current_stable_seq);
+  }
+
+ private:
+  std::atomic<uint64_t> current_stable_seq_{0};
+};
+
+}  // namespace resdb
diff --git a/platform/consensus/ordering/raft/proto/BUILD b/platform/consensus/ordering/raft/proto/BUILD
new file mode 100644
index 0000000000..114b78e24b
--- /dev/null
+++ b/platform/consensus/ordering/raft/proto/BUILD
@@ -0,0 +1,36 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+package(default_visibility = ["//platform/consensus/ordering/raft:__subpackages__"])  # default: only packages under raft/ may depend on these targets
+
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
+load("@rules_proto//proto:defs.bzl", "proto_library")
+load("@rules_proto_grpc//python:defs.bzl", "python_proto_library")  # NOTE(review): no rule in this file uses python_proto_library
+
+proto_library(  # schema target for the Raft wire/WAL messages in proposal.proto
+    name = "proposal_proto",
+    srcs = ["proposal.proto"],
+    #visibility = ["//visibility:public"],
+)
+
+cc_proto_library(  # C++ bindings generated from :proposal_proto
+    name = "proposal_cc_proto",
+    deps = [":proposal_proto"],
+    visibility = ["//platform/consensus/ordering/raft:__subpackages__",
+                  "//platform/consensus/recovery:__subpackages__"],  # recovery/ is listed in addition to the package default
+)
diff --git a/platform/consensus/ordering/raft/proto/proposal.proto b/platform/consensus/ordering/raft/proto/proposal.proto
new file mode 100644
index 0000000000..835d5132f3
--- /dev/null
+++ b/platform/consensus/ordering/raft/proto/proposal.proto
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+syntax = "proto3";
+
+package resdb.raft;
+/*
+ * One log entry: the term it was created in plus an opaque command payload.
+ */
+message Entry {
+  uint64 term = 1;
+  bytes command = 2;
+}
+/*
+ * Marks a log truncation: the entry at truncate_from_index and every entry
+ * after it are discarded.
+ */
+message TruncationRecord {
+  // Including this index, remove it and everything after it.
+  uint64 truncate_from_index = 1;
+  uint64 truncate_from_term = 2;
+}
+/*
+ * A single write-ahead-log record: a sequence number plus exactly one of an
+ * appended Entry or a TruncationRecord.
+ */
+message WALRecord {
+  uint64 seq = 1;
+  oneof payload {
+    Entry entry = 2;
+    TruncationRecord truncation = 3;
+  }
+}
+/*
+ * NOTE(review): the messages below presumably mirror the RPCs of the Raft
+ * paper (AppendEntries / RequestVote and their replies) - confirm against the
+ * algorithm implementation.
+ */
+message AppendEntries{
+  uint64 term = 1;
+  int32 leaderId = 2;
+  uint64 prevLogIndex = 3;
+  uint64 prevLogTerm = 4;
+  repeated Entry entries = 5;
+  uint64 leaderCommitIndex = 6;
+}
+
+message AppendEntriesResponse {
+  uint64 term = 1;
+  bool success = 2;
+  int32 id = 3;
+  uint64 lastLogIndex = 5;  // field number 4 is not used by this message
+}
+
+message RequestVote {
+  uint64 term = 1;
+  int32 candidateId = 2;
+  uint64 lastLogIndex = 3;
+  uint64 lastLogTerm = 4;
+}
+
+message RequestVoteResponse {
+  uint64 term = 1;
+  bool voteGranted = 2;
+  int32 voterId = 3;
+}
+
+message DirectToLeader {
+  uint64 term = 1;
+  int32 leaderId = 2;
+}
+/*
+ * Tag with one value per message type defined above; None (0) is the proto3
+ * default.
+ */
+enum MessageType {
+  None = 0;
+  AppendEntriesMsg = 1;
+  AppendEntriesResponseMsg = 2;
+  RequestVoteMsg = 3;
+  RequestVoteResponseMsg = 4;
+  DirectToLeaderMsg = 5;
+}
+
diff --git a/platform/consensus/recovery/BUILD b/platform/consensus/recovery/BUILD
index d429c6a7e8..e9e74755ad 100644
--- a/platform/consensus/recovery/BUILD
+++ b/platform/consensus/recovery/BUILD
@@ -19,17 +19,25 @@ package(default_visibility = ["//platform/consensus:__subpackages__"]) cc_library( - name = "recovery", - srcs = ["recovery.cpp"], - hdrs = ["recovery.h"], + name = "recovery_base", + hdrs = ["recovery.h", "recovery_impl.h"], deps = [ "//chain/storage", "//common/utils", "//platform/config:resdb_config", "//platform/consensus/checkpoint", - "//platform/consensus/execution:system_info", "//platform/networkstrate:server_comm", "//platform/proto:resdb_cc_proto", + ], +) + +cc_library( + name =
"recovery", + srcs = ["pbft_recovery.cpp"], + hdrs = ["pbft_recovery.h"], + deps = [ + ":recovery_base", + "//platform/consensus/execution:system_info", "//platform/proto:system_info_data_cc_proto", ], ) @@ -45,3 +53,50 @@ cc_test( "//platform/consensus/ordering/common:transaction_utils", ], ) + +cc_library( + name = "mock_raft_recovery", + hdrs = ["mock_raft_recovery.h"], + testonly = True, + deps = [ + ":raft_recovery", + "//chain/storage:mock_storage", + "//platform/consensus/checkpoint:mock_checkpoint" + ], +) + +cc_library( + name = "raft_recovery", + srcs = ["raft_recovery.cpp"], + hdrs = ["raft_recovery.h"], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + "//chain/storage", + "//common/utils", + "//platform/consensus/ordering/raft/proto:proposal_cc_proto", + "//platform/config:resdb_config", + "//platform/consensus/ordering/raft/framework:raft_checkpoint_manager", + "//platform/networkstrate:server_comm", + "//platform/proto:resdb_cc_proto", + "//platform/consensus/recovery:recovery_base" + ], +) + +cc_test( + name = "raft_recovery_test", + srcs = [ + "raft_recovery_test.cpp", + ], + copts = ["-DRAFT_TEST_MODE"], + deps = [ + ":raft_recovery", + "//chain/storage:mock_storage", + "//platform/consensus/ordering/raft/proto:proposal_cc_proto", + "//platform/consensus/checkpoint:mock_checkpoint", + "//platform/consensus/ordering/common:transaction_utils", + "//common/test:test_main", + "//platform/proto:client_test_cc_proto", + "//platform/consensus/ordering/raft/algorithm:raft" + ], + size="small" +) diff --git a/platform/consensus/recovery/mock_raft_recovery.h b/platform/consensus/recovery/mock_raft_recovery.h new file mode 100644 index 0000000000..136997263a --- /dev/null +++ b/platform/consensus/recovery/mock_raft_recovery.h @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include "platform/consensus/recovery/raft_recovery.h" +#include "platform/consensus/checkpoint/mock_checkpoint.h" +#include "chain/storage/mock_storage.h" + +namespace resdb { +namespace raft { + +class MockRaftRecovery : public RaftRecovery { + public: + MockRaftRecovery(const ResDBConfig& config) + : RaftRecovery(config, mock_checkpoint_.get(), mock_storage_.get(), + nullptr) {} + + MOCK_METHOD(void, AddLogEntry, (const Entry* entry), ()); + MOCK_METHOD(void, WriteMetadata, + (int64_t current_term, int32_t voted_for, + uint64_t snapshot_last_index, uint64_t snapshot_last_term), + ()); + MOCK_METHOD(void, AddLogEntry, (std::vector& entries_to_add), ()); + MOCK_METHOD(void, TruncateLog, (TruncationRecord truncate_beginning_at), ()); + + std::unique_ptr mock_checkpoint_; + std::unique_ptr mock_storage_; +}; + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/recovery/pbft_recovery.cpp b/platform/consensus/recovery/pbft_recovery.cpp new file mode 100644 index 0000000000..65e893ed0a --- /dev/null +++ b/platform/consensus/recovery/pbft_recovery.cpp @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "platform/consensus/recovery/pbft_recovery.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/utils/utils.h" + +namespace resdb { + +using CallbackType = + std::function, std::unique_ptr)>; + +PBFTRecovery::PBFTRecovery(const ResDBConfig& config, CheckPoint* checkpoint, + SystemInfo* system_info, Storage* storage) + : RecoveryBase( + config, checkpoint, storage), + system_info_(system_info) { + Init(); +} + +void PBFTRecovery::Init() { + if (recovery_enabled_ == false) { + LOG(INFO) << "recovery is not enabled:" << recovery_enabled_; + return; + } + + LOG(ERROR) << " init"; + GetLastFile(); + + CallbackType callback = [this](std::unique_ptr context, + std::unique_ptr request) { + min_seq_ = (min_seq_ == -1) + ? 
request->seq() + : std::min(min_seq_, static_cast(request->seq())); + max_seq_ = std::max(max_seq_, static_cast(request->seq())); + }; + + SwitchFile(file_path_, callback); + + LOG(ERROR) << " init done"; + + ckpt_thread_ = std::thread([this] { this->UpdateStableCheckPoint(); }); +} + +void PBFTRecovery::WriteSystemInfo() { + int view = system_info_->GetCurrentView(); + int primary_id = system_info_->GetPrimaryId(); + LOG(ERROR) << "write system info:" << primary_id << " view:" << view; + SystemInfoData data; + data.set_view(view); + data.set_primary_id(primary_id); + + std::string data_str; + data.SerializeToString(&data_str); + + AppendData(data_str); + Flush(); +} + +void PBFTRecovery::AddRequest(const Context* context, const Request* request) { + if (recovery_enabled_ == false) { + return; + } + switch (request->type()) { + case Request::TYPE_PRE_PREPARE: + case Request::TYPE_PREPARE: + case Request::TYPE_COMMIT: + case Request::TYPE_NEWVIEW: + return WriteLog(context, request); + default: + break; + } +} + +void PBFTRecovery::WriteLog(const Context* context, const Request* request) { + std::string data; + if (request) { + request->SerializeToString(&data); + } + + std::string sig; + if (context) { + context->signature.SerializeToString(&sig); + } + + std::unique_lock lk(mutex_); + min_seq_ = min_seq_ == -1 + ? 
request->seq() + : std::min(min_seq_, static_cast(request->seq())); + max_seq_ = std::max(max_seq_, static_cast(request->seq())); + AppendData(data); + AppendData(sig); + + Flush(); +} + +std::vector> +PBFTRecovery::ParseDataListItem(std::vector& data_list) { + std::vector> request_list; + + for (size_t i = 0; i < data_list.size(); i += 2) { + std::unique_ptr recovery_data = + std::make_unique(); + recovery_data->request = std::make_unique(); + recovery_data->context = std::make_unique(); + + if (!recovery_data->request->ParseFromString(data_list[i])) { + LOG(ERROR) << "Parse from data fail"; + break; + } + + if (!recovery_data->context->signature.ParseFromString(data_list[i + 1])) { + LOG(ERROR) << "Parse from data fail"; + break; + } + + request_list.push_back(std::move(recovery_data)); + } + return request_list; +} + +void PBFTRecovery::PerformCallback( + std::vector>& request_list, + CallbackType call_back, int64_t ckpt) { + uint64_t max_seq = 0; + for (std::unique_ptr& recovery_data : request_list) { + // LOG(ERROR)<<" ckpt :"<request->seq()<<" + // type:"<request->type(); + if (ckpt < recovery_data->request->seq() || + recovery_data->request->type() == Request::TYPE_NEWVIEW) { + recovery_data->request->set_is_recovery(true); + max_seq = recovery_data->request->seq(); + call_back(std::move(recovery_data->context), + std::move(recovery_data->request)); + } + } + + LOG(ERROR) << " recovery max seq:" << max_seq; +} + +bool PBFTRecovery::PerformSystemCallback( + std::vector data_list, + std::function system_callback) { + SystemInfoData info; + if (data_list.empty() || !info.ParseFromString(data_list[0])) { + return false; + } + LOG(ERROR) << "read system info:" << info.DebugString(); + system_callback(info); + return true; +} + +void PBFTRecovery::HandleSystemInfo( + int fd, std::function system_callback) { + size_t data_len = 0; + Read(fd, sizeof(data_len), reinterpret_cast(&data_len)); + std::string data; + char* buf = new char[data_len]; + if (!Read(fd, 
data_len, buf)) { + LOG(ERROR) << "Read system info fail"; + return; + } + data = std::string(buf, data_len); + delete buf; + std::vector data_list = ParseRawData(data); + + bool successful_callback = PerformSystemCallback(data_list, system_callback); + + if (!successful_callback) { + LOG(ERROR) << "parse info fail:" << data.size(); + } +} + +std::map< + uint64_t, + std::vector, std::unique_ptr>>> +PBFTRecovery::GetDataFromRecoveryFiles(uint64_t need_min_seq, + uint64_t need_max_seq) { + auto list = GetSortedRecoveryFiles(need_min_seq, need_max_seq); + + std::map, + std::unique_ptr>>> + res; + for (const auto& path : list) { + CallbackType callback = [&](std::unique_ptr context, + std::unique_ptr request) { + if (request->seq() >= need_min_seq && request->seq() <= need_max_seq) { + LOG(ERROR) << "get data from recovery file seq:" << request->seq(); + res[request->seq()].push_back( + std::make_pair(std::move(context), std::move(request))); + } + }; + + ReadLogsFromFiles( + path.second, need_min_seq - 1, 0, [&](const SystemInfoData& data) {}, + callback); + } + + return res; +} + +int PBFTRecovery::GetData(const RecoveryRequest& request, + RecoveryResponse& response) { + auto res = GetDataFromRecoveryFiles(request.min_seq(), request.max_seq()); + + for (const auto& it : res) { + for (const auto& req : it.second) { + *response.add_signature() = req.first->signature; + *response.add_request() = *req.second; + } + } + return 0; +} + +} // namespace resdb diff --git a/platform/consensus/recovery/pbft_recovery.h b/platform/consensus/recovery/pbft_recovery.h new file mode 100644 index 0000000000..6925b8ba16 --- /dev/null +++ b/platform/consensus/recovery/pbft_recovery.h @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#pragma once
+
+#include "platform/consensus/execution/system_info.h"
+#include "platform/consensus/recovery/recovery.h"
+#include "platform/proto/system_info_data.pb.h"
+
+namespace resdb {
+using CallbackType =
+    std::function, std::unique_ptr)>;  // invoked once per recovered (context, request) pair
+/**
+ * Write-ahead log and recovery support for PBFT, built on RecoveryBase.
+ *
+ * Persists selected consensus messages together with their signatures and
+ * can hand them back, either through the RecoveryBase replay machinery or
+ * through GetData(). RecoveryBase is a friend so it can reach the private
+ * hooks declared below.
+ */
+class PBFTRecovery
+    : public RecoveryBase {
+  friend class RecoveryBase;
+
+ public:
+  PBFTRecovery(const ResDBConfig& config, CheckPoint* checkpoint,
+               SystemInfo* system_info, Storage* storage);  // system_info is stored as a raw, non-owning pointer
+  ~PBFTRecovery() = default;
+  /*
+   * Logs `request` (plus the signature in `context`) when recovery is enabled
+   * and the request is a PRE_PREPARE, PREPARE, COMMIT or NEWVIEW message;
+   * all other request types are ignored.
+   */
+  void AddRequest(const Context* context, const Request* request);
+  /*
+   * Returns the logged (context, request) pairs whose sequence number lies in
+   * [need_min_seq, need_max_seq], grouped by sequence number.
+   */
+  std::map,
+                                 std::unique_ptr>>>
+  GetDataFromRecoveryFiles(uint64_t need_min_seq, uint64_t need_max_seq);
+  /*
+   * Fills `response` with the signatures and requests logged for the range
+   * [request.min_seq(), request.max_seq()]. Always returns 0.
+   */
+  int GetData(const RecoveryRequest& request, RecoveryResponse& response);
+
+ private:
+  struct RecoveryData {  // one decoded log record: a request and its signature context
+    std::unique_ptr context;
+    std::unique_ptr request;
+  };
+
+  void Init();  // called from the constructor; returns early when recovery is disabled
+  void WriteLog(const Context* context, const Request* request);
+  void WriteSystemInfo();  // appends the current view and primary id as SystemInfoData
+  /*
+   * Decodes a flat list of alternating serialised requests and signatures
+   * into RecoveryData records.
+   */
+  std::vector> ParseDataListItem(
+      std::vector& data_list);
+  /*
+   * Replays to `call_back` the records newer than checkpoint `ckpt`, plus
+   * every NEWVIEW record.
+   */
+  void PerformCallback(std::vector>& request_list,
+                       CallbackType call_back, int64_t ckpt);
+  /*
+   * Parses data_list[0] as SystemInfoData and passes it to `system_callback`;
+   * returns false when the list is empty or the item does not parse.
+   */
+  bool PerformSystemCallback(
+      std::vector data_list,
+      std::function system_callback);
+  /*
+   * Reads the length-prefixed system-info block from `fd` and forwards it via
+   * PerformSystemCallback().
+   */
+  void HandleSystemInfo(
+      int fd, std::function system_callback);
+
+  SystemInfo* system_info_;  // not owned; the class never deletes it
+};
+
+} // namespace resdb
diff --git a/platform/consensus/recovery/raft_recovery.cpp
b/platform/consensus/recovery/raft_recovery.cpp new file mode 100644 index 0000000000..b0d5d421df --- /dev/null +++ b/platform/consensus/recovery/raft_recovery.cpp @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "platform/consensus/recovery/raft_recovery.h" + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common/utils/utils.h" + +namespace resdb { +namespace raft { + +using CallbackType = std::function)>; + +RaftRecovery::RaftRecovery(const ResDBConfig& config, CheckPoint* checkpoint, + Storage* storage, + std::function on_checkpoint) + : RecoveryBase( + config, checkpoint, storage, on_checkpoint) { + Init(); +} + +void RaftRecovery::Init() { + if (recovery_enabled_ == false) { + LOG(INFO) << "recovery is not enabled:" << recovery_enabled_; + return; + } + + LOG(ERROR) << " init"; + GetLastFile(); + + meta_file_path_ = std::filesystem::path(base_file_path_).parent_path() / + "raft_metadata.dat"; + LOG(INFO) << "Meta file path: " << meta_file_path_; + OpenMetadataFile(); + + CallbackType callback = [this](std::unique_ptr record) { + min_seq_ = min_seq_ == -1 + ? 
record->seq() + : std::min(min_seq_, static_cast(record->seq())); + max_seq_ = std::max(max_seq_, static_cast(record->seq())); + }; + + SwitchFile(file_path_, callback); + LOG(ERROR) << " init done"; + + ckpt_thread_ = std::thread([this] { this->UpdateStableCheckPoint(); }); +} + +RaftRecovery::~RaftRecovery() { + if (recovery_enabled_ == false) { + return; + } + Flush(); + if (metadata_fd_ >= 0) { + close(metadata_fd_); + } +} + +void RaftRecovery::OpenMetadataFile() { + LOG(INFO) << "Opening Metadata File"; + metadata_fd_ = open(meta_file_path_.c_str(), O_CREAT | O_RDWR, 0666); + if (metadata_fd_ < 0) { + LOG(ERROR) << "Failed to open metadata file: " << strerror(errno); + return; + } +} + +void RaftRecovery::WriteMetadata(int64_t current_term, int32_t voted_for, + uint64_t snapshot_last_index, + uint64_t snapshot_last_term) { + if (recovery_enabled_ == false) { + return; + } + + std::string tmp_path = meta_file_path_ + ".tmp"; + LOG(INFO) << "tmp_path = [" << tmp_path << "]"; + LOG(INFO) << "meta_file_path_ = [" << meta_file_path_ << "]"; + + int temp_fd = open(tmp_path.c_str(), O_CREAT | O_WRONLY | O_TRUNC, 0666); + if (temp_fd < 0) { + LOG(ERROR) << "Failed to open tmp metadata file: " << strerror(errno); + return; + } + + RaftMetadata new_metadata; + new_metadata.current_term = current_term; + new_metadata.voted_for = voted_for; + new_metadata.snapshot_last_index = snapshot_last_index; + new_metadata.snapshot_last_term = snapshot_last_term; + + ssize_t bytes_written = write(temp_fd, &new_metadata, sizeof(new_metadata)); + if (bytes_written != static_cast(sizeof(new_metadata))) { + LOG(ERROR) << "Failed to write metadata (wrote " << bytes_written << " of " + << sizeof(new_metadata) << " bytes): " << strerror(errno); + close(temp_fd); + unlink(tmp_path.c_str()); + return; + } + + if (fsync(temp_fd) < 0) { + LOG(ERROR) << "Failed to fsync tmp metadata file: " << strerror(errno); + close(temp_fd); + unlink(tmp_path.c_str()); + return; + } + close(temp_fd); + + if 
(rename(tmp_path.c_str(), meta_file_path_.c_str()) < 0) { + LOG(ERROR) << "Failed to rename tmp metadata file: " << strerror(errno); + unlink(tmp_path.c_str()); + return; + } + + // Only fsync and close the dir if it opens properly + std::string dir_path = std::filesystem::path(meta_file_path_).parent_path().string(); + int dir_fd = open(dir_path.c_str(), O_RDONLY); + if (dir_fd < 0) { + LOG(ERROR) << "Failed to open directory for fsync: " << strerror(errno); + } else { + if (fsync(dir_fd) < 0) { + LOG(ERROR) << "Failed to fsync directory: " << strerror(errno); + } + close(dir_fd); + } + + metadata_ = new_metadata; + + LOG(INFO) << "Wrote metadata: term: " << current_term + << " votedFor: " << voted_for + << " snapshot last index: " << snapshot_last_index + << " snapshot last term: " << snapshot_last_term; + LOG(INFO) << "METADATA location: " << meta_file_path_; +} + +RaftMetadata RaftRecovery::ReadMetadata() { + if (recovery_enabled_ == false) { + return RaftMetadata{}; + } + + RaftMetadata metadata; + if (metadata_fd_ < 0) { + LOG(ERROR) << "Metadata file either never opened or already closed " + "(meaning ReadMetadata() has been called before)"; + return metadata; + } + + lseek(metadata_fd_, 0, SEEK_SET); + int bytes = read(metadata_fd_, &metadata, sizeof(metadata)); + if (bytes != sizeof(metadata)) { + LOG(INFO) << "No existing metadata, using defaults"; + return RaftMetadata{}; + } + + LOG(INFO) << "Read metadata file: term: " << metadata.current_term + << " votedFor: " << metadata.voted_for + << " snapshot_last_index: " << metadata.snapshot_last_index + << " snapshot_last_term: " << metadata.snapshot_last_term; + return metadata; +} + +void RaftRecovery::WriteSystemInfo() {} + +void RaftRecovery::AddLogEntry(const Entry* entry, int64_t seq) { + if (recovery_enabled_ == false) { + return; + } + + std::unique_lock lk(mutex_); + WALRecord record; + *record.mutable_entry() = *entry; + record.set_seq(seq); + WriteLog(record); + Flush(); +} + +void 
RaftRecovery::AddLogEntry(std::vector& entries_to_add, + int64_t seq) { + if (recovery_enabled_ == false || entries_to_add.size() == 0) { + return; + } + + std::unique_lock lk(mutex_); + for (const auto& entry : entries_to_add) { + WALRecord record; + *record.mutable_entry() = entry; + record.set_seq(seq++); + WriteLog(record); + } + Flush(); +} + +void RaftRecovery::TruncateLog(TruncationRecord truncate_beginning_at) { + if (recovery_enabled_ == false) { + return; + } + + std::unique_lock lk(mutex_); + + WALRecord record; + record.set_seq(truncate_beginning_at.truncate_from_index() - 1); + *record.mutable_truncation() = std::move(truncate_beginning_at); + + WriteLog(record); + Flush(); +} + +void RaftRecovery::WriteLog(const WALRecord& record) { + std::string data; + + record.SerializeToString(&data); + + + switch (record.payload_case()) { + case WALRecord::kEntry: { + min_seq_ = min_seq_ == -1 + ? record.seq() + : std::min(min_seq_, static_cast(record.seq())); + max_seq_ = std::max(max_seq_, static_cast(record.seq())); + break; + } + case WALRecord::kTruncation: { + int64_t keep_up_to = static_cast(record.seq()); + if (max_seq_ > keep_up_to) { + max_seq_ = keep_up_to; + } + // If we truncate everything, reset min and max seq + if (max_seq_ <= last_ckpt_) { + min_seq_ = -1; + max_seq_ = -1; + } else { + min_seq_ = + (min_seq_ == -1) ? 
keep_up_to : std::min(min_seq_, keep_up_to); + } + break; + } + case WALRecord::PAYLOAD_NOT_SET: { + assert(false && "WALRecord does not contain Truncation or Entry"); + break; + } + } + + AppendData(data); +} + +std::vector> RaftRecovery::ParseDataListItem( + std::vector& data_list) { + std::vector> record_list; + + for (size_t i = 0; i < data_list.size(); i++) { + std::unique_ptr record = std::make_unique(); + + if (!record->ParseFromString(data_list[i])) { + LOG(ERROR) << "Parse from data fail"; + break; + } + + record_list.push_back(std::move(record)); + } + return record_list; +} + +void RaftRecovery::PerformCallback( + std::vector>& record_list, CallbackType call_back, + int64_t ckpt) { + uint64_t max_seq = 0; + for (std::unique_ptr& record : record_list) { + // Only replay entries that are after the latest checkpoint. + // Since truncation records store the seq of the last index remaining in the + // log, it could be equal to the ckpt, meaning that everything since the + // checkpoint is to be truncated. + if (ckpt < record->seq() || + (ckpt == record->seq() && + record->payload_case() == WALRecord::kTruncation)) { + max_seq = record->seq(); + call_back(std::move(record)); + } + } + + LOG(ERROR) << " recovery max seq:" << max_seq; +} + +void RaftRecovery::HandleSystemInfo( + int /*fd*/, std::function system_callback) { + metadata_ = ReadMetadata(); + LOG(ERROR) << " metadata_.voted_for: " << metadata_.voted_for + << "\nmetadata_.current_term " << metadata_.current_term; + system_callback(metadata_); +} + +} // namespace raft + +} // namespace resdb diff --git a/platform/consensus/recovery/raft_recovery.h b/platform/consensus/recovery/raft_recovery.h new file mode 100644 index 0000000000..0ab7ffaa97 --- /dev/null +++ b/platform/consensus/recovery/raft_recovery.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include "chain/storage/storage.h" +#include "platform/config/resdb_config.h" +#include "platform/consensus/checkpoint/checkpoint.h" +#include "platform/consensus/ordering/raft/framework/raft_checkpoint_manager.h" +#include "platform/consensus/ordering/raft/proto/proposal.pb.h" +#include "platform/consensus/recovery/recovery.h" +#include "platform/networkstrate/server_comm.h" +#include "platform/proto/resdb.pb.h" + +namespace resdb { + +namespace raft { + +struct RaftMetadata { + int64_t current_term = 0; + int32_t voted_for = -1; + uint64_t snapshot_last_index = 0; + uint64_t snapshot_last_term = 0; +}; + +using CallbackType = std::function)>; + +class RaftRecovery + : public RecoveryBase { + friend class RecoveryBase; + + public: + RaftRecovery(const ResDBConfig& config, CheckPoint* checkpoint, + Storage* storage, std::function on_checkpoint); + ~RaftRecovery(); + + RaftMetadata ReadMetadata(); + void Init(); + void WriteMetadata(int64_t current_term, int32_t voted_for, + uint64_t snapshot_last_index, uint64_t snapshot_last_term); + void AddLogEntry(const Entry* entry, int64_t seq); + void AddLogEntry(std::vector& entries_to_add, int64_t seq); + void TruncateLog(TruncationRecord truncate_beginning_at); + +#ifdef RAFT_RECOVERY_TEST_MODE + 
std::string GetMetadataFilePath() { return meta_file_path_; } + + std::string GetFilePath() { return file_path_; } +#endif + + private: + void OpenMetadataFile(); + void WriteSystemInfo(); + std::vector> ParseDataListItem( + std::vector& data_list); + void WriteLog(const WALRecord& record); + + void PerformCallback( + std::vector>& request_list, + std::function record)> call_back, + int64_t ckpt); + + void HandleSystemInfo( + int /*fd*/, std::function system_callback); + + int metadata_fd_; + std::string meta_file_path_; + RaftMetadata metadata_; +}; + +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/recovery/raft_recovery_test.cpp b/platform/consensus/recovery/raft_recovery_test.cpp new file mode 100644 index 0000000000..3bcabfa292 --- /dev/null +++ b/platform/consensus/recovery/raft_recovery_test.cpp @@ -0,0 +1,800 @@ +#include "platform/consensus/recovery/raft_recovery.h" + +#include +#include +#include + +#include +#include + +#include "chain/storage/mock_storage.h" +#include "platform/consensus/checkpoint/mock_checkpoint.h" +#include "platform/consensus/ordering/common/transaction_utils.h" +#include "platform/consensus/ordering/raft/proto/proposal.pb.h" + +namespace resdb { +namespace raft { +using ::testing::_; +using ::testing::AnyNumber; +using ::testing::Invoke; +using ::testing::Matcher; +using ::testing::Return; +using ::testing::Test; + +const std::string log_path = "./log/test_log"; + +ResConfigData GetConfigData(int buf_size = 10) { + ResConfigData data; + data.set_recovery_enabled(true); + data.set_recovery_path(log_path); + data.set_recovery_buffer_size(buf_size); + data.set_recovery_ckpt_time_s(1); + + return data; +} + +std::vector Listlogs(const std::string &path) { + std::vector ret; + std::string dir = std::filesystem::path(path).parent_path(); + for (const auto &entry : std::filesystem::directory_iterator(dir)) { + LOG(ERROR) << "path:" << entry.path(); + ret.push_back(entry.path()); + } + return ret; +} + +static 
Entry CreateTestEntry(RaftRecovery &recovery, int term, int seq) { + Entry logEntry; + logEntry.set_term(term); + auto req = std::make_unique(); + req->set_seq(seq); + req->set_data("Request " + std::to_string(seq)); + std::string serialized; + EXPECT_TRUE(req->SerializeToString(&serialized)); + logEntry.set_command(std::move(serialized)); + return logEntry; +} + +static void AddTestEntry(RaftRecovery &recovery, int term, int seq) { + Entry logEntry = CreateTestEntry(recovery, term, seq); + recovery.AddLogEntry(&logEntry, seq); +} + +class RaftRecoveryTest : public Test { + public: + RaftRecoveryTest() + : config_(GetConfigData(), ReplicaInfo(), KeyInfo(), CertificateInfo()) { + std::string dir = std::filesystem::path(log_path).parent_path(); + std::filesystem::remove_all(dir); + } + + protected: + ResDBConfig config_; + MockCheckPoint checkpoint_; +}; + +TEST_F(RaftRecoveryTest, WriteAndReadLog) { + int entries_to_add = 3; + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= entries_to_add; i++) { + AddTestEntry(recovery, i, i); + } + } + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + EXPECT_EQ(list.size(), entries_to_add); + + for (size_t i = 0; i < entries_to_add; ++i) { + EXPECT_EQ(list[i].payload_case(), WALRecord::kEntry); + + EXPECT_EQ(list[i].entry().term(), i + 1); + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.data(), "Request " + std::to_string(i + 1)); + } + } +} + +TEST_F(RaftRecoveryTest, WriteMultipleEntriesAndReadLog) { + int entries_to_add = 3; + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + std::vector log_entries; + for (int i = 1; i <= entries_to_add; i++) { + log_entries.push_back(CreateTestEntry(recovery, i, i)); + } + recovery.AddLogEntry(log_entries, 1); + } + { + 
std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, nullptr); + + EXPECT_EQ(list.size(), entries_to_add); + + for (size_t i = 0; i < entries_to_add; ++i) { + EXPECT_EQ(list[i].payload_case(), WALRecord::kEntry); + + EXPECT_EQ(list[i].entry().term(), i + 1); + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.data(), "Request " + std::to_string(i + 1)); + } + } +} + +TEST_F(RaftRecoveryTest, WriteAndReadMetadata) { + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + recovery.WriteMetadata(2, 3, 100, 1); + } + { + int64_t current_term; + int32_t voted_for; + uint64_t snapshot_last_index; + uint64_t snapshot_last_term; + + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) { + current_term = data.current_term; + voted_for = data.voted_for; + snapshot_last_index = data.snapshot_last_index; + snapshot_last_term = data.snapshot_last_term; + }, + [&](std::unique_ptr record) {}, nullptr); + + EXPECT_EQ(current_term, 2); + EXPECT_EQ(voted_for, 3); + EXPECT_EQ(snapshot_last_index, 100); + EXPECT_EQ(snapshot_last_term, 1); + } +} + +TEST_F(RaftRecoveryTest, WriteAndReadMetadataTwice) { + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + recovery.WriteMetadata(2, 3, 100, 1); + recovery.WriteMetadata(4, 2, 200, 2); + } + { + int64_t current_term; + int32_t voted_for; + uint64_t snapshot_last_index; + uint64_t snapshot_last_term; + + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) { + current_term = data.current_term; + voted_for = data.voted_for; + snapshot_last_index = data.snapshot_last_index; + snapshot_last_term = data.snapshot_last_term; + }, + [&](std::unique_ptr record) {}, nullptr); + + EXPECT_EQ(current_term, 4); + 
EXPECT_EQ(voted_for, 2); + EXPECT_EQ(snapshot_last_index, 200); + EXPECT_EQ(snapshot_last_term, 2); + } +} + +TEST_F(RaftRecoveryTest, ReadMetadataDefaultValues) { + { + int64_t current_term; + int32_t voted_for; + uint64_t snapshot_last_index; + uint64_t snapshot_last_term; + + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) { + current_term = data.current_term; + voted_for = data.voted_for; + snapshot_last_index = data.snapshot_last_index; + snapshot_last_term = data.snapshot_last_term; + }, + [&](std::unique_ptr record) {}, nullptr); + + EXPECT_EQ(current_term, 0); + EXPECT_EQ(voted_for, -1); + EXPECT_EQ(snapshot_last_index, 0); + EXPECT_EQ(snapshot_last_term, 0); + } +} + +TEST_F(RaftRecoveryTest, TruncateLog) { + int entries_to_add = 4; + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= entries_to_add; i++) { + AddTestEntry(recovery, i, i); + } + + TruncationRecord truncation; + truncation.set_truncate_from_index(3); + truncation.set_truncate_from_term(3); + recovery.TruncateLog(truncation); + + for (int i = 5; i <= entries_to_add * 2; i++) { + AddTestEntry(recovery, i + 1, i); + } + } + /* Recovery WAL + Term Seq Data + list[0] 1 1 Request 1 + list[1] 2 2 Request 2 + list[2] 3 3 Request 3 + list[3] 4 4 Request 4 + list[4] Truncate beginning at Seq 3 + list[5] 6 5 Request 5 + list[6] 7 6 Request 6 + list[7] 8 7 Request 7 + list[8] 9 8 Request 8 + */ + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + EXPECT_EQ(list.size(), 2 * entries_to_add + 1); + + for (size_t i = 0; i < entries_to_add; ++i) { + EXPECT_EQ(list[i].payload_case(), WALRecord::kEntry); + EXPECT_EQ(list[i].entry().term(), i + 1); + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.data(), 
"Request " + std::to_string(i + 1)); + EXPECT_EQ(req.seq(), i + 1); + } + + EXPECT_EQ(list[4].payload_case(), WALRecord::kTruncation); + EXPECT_EQ(list[4].truncation().truncate_from_index(), 3); + + for (size_t i = entries_to_add + 1; i < 2 * entries_to_add + 1; ++i) { + EXPECT_EQ(list[i].payload_case(), WALRecord::kEntry); + EXPECT_EQ(list[i].entry().term(), i + 1); + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.data(), "Request " + std::to_string(i)); + EXPECT_EQ(req.seq(), i); + } + } +} + +// After a checkpoint fires and the log file is rotated, there should be exactly +// two .log files on disk: the sealed (checkpointed) file and the new active +// one. +TEST_F(RaftRecoveryTest, CheckpointCreatesNewLogFile) { + std::promise insert_done, ckpt_fired; + auto insert_done_future = insert_done.get_future(); + auto ckpt_fired_future = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_future.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= 9; i++) { + AddTestEntry(recovery, i, i); + } + insert_done.set_value(true); + ckpt_fired_future.get(); + + // Write some more entries into the new file. + for (int i = 10; i <= 18; i++) { + AddTestEntry(recovery, i, i); + } + } + + std::vector log_list = Listlogs(log_path); + // 2 log files and one metadata file + EXPECT_EQ(log_list.size(), 3); +} + +// After a checkpoint at stable_seq=5, ReadLogs should only replay WAL records +// whose seq is strictly greater than 5. 
+TEST_F(RaftRecoveryTest, CheckpointFiltersOldEntries) { + std::promise insert_done, ckpt_fired; + auto insert_done_future = insert_done.get_future(); + auto ckpt_fired_future = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_future.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= 9; i++) { + AddTestEntry(recovery, i, i); + } + insert_done.set_value(true); + ckpt_fired_future.get(); + } + + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + // Only WAL seqs 6-9 should be replayed (4 entries). + ASSERT_EQ(list.size(), 4u); + for (size_t i = 0; i < list.size(); ++i) { + EXPECT_EQ(list[i].payload_case(), WALRecord::kEntry); + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.seq(), (int)(i + 6)); + } + } +} + +// After a checkpoint rotation, GetMinSeq()/GetMaxSeq() should reset to -1 for +// the newly opened (empty) file, then update as new entries are appended. 
+TEST_F(RaftRecoveryTest, CheckpointResetsMinMaxSeq) { + std::promise insert_done, ckpt_fired; + auto insert_done_future = insert_done.get_future(); + auto ckpt_fired_future = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_future.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= 5; i++) { + AddTestEntry(recovery, i, i); + } + insert_done.set_value(true); + ckpt_fired_future.get(); + + EXPECT_EQ(recovery.GetMinSeq(), -1); + EXPECT_EQ(recovery.GetMaxSeq(), -1); + + // Add entries to the new file and verify the range is tracked correctly. + for (int i = 6; i <= 9; i++) { + AddTestEntry(recovery, i, i); + } + + EXPECT_EQ(recovery.GetMinSeq(), 6); + EXPECT_EQ(recovery.GetMaxSeq(), 9); + } +} + +// Two successive checkpoints. After both fires, only entries whose WAL seq +// exceeds the second checkpoint value (15) survive replay. +TEST_F(RaftRecoveryTest, TwoCheckpoints) { + std::promise ins1, ck1, ins2, ck2; + auto ins1f = ins1.get_future(), ck1f = ck1.get_future(); + auto ins2f = ins2.get_future(), ck2f = ck2.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + ins1f.get(); + else if (call_count == 2) + ck1.set_value(true); + else if (call_count == 3) + ins2f.get(); + else if (call_count == 4) + ck2.set_value(true); + return (call_count <= 2) ? 
5 : 15; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= 9; i++) { + AddTestEntry(recovery, i, i); + } + ins1.set_value(true); + ck1f.get(); + + for (int i = 10; i <= 18; i++) { + AddTestEntry(recovery, i, i); + } + ins2.set_value(true); + ck2f.get(); + + // Third window: entries 19-22. + for (int i = 19; i <= 22; i++) { + AddTestEntry(recovery, i, i); + } + } + + std::vector log_list = Listlogs(log_path); + // 3 log files and one metadata file + EXPECT_EQ(log_list.size(), 4); + + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + // ckpt=15: entries with WAL seq > 15 survive: seqs 16-22 (7 entries). + ASSERT_EQ(list.size(), 7u); + for (size_t i = 0; i < list.size(); ++i) { + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.seq(), (int)(i + 16)); + } + // Even though seqs 16-22 survive, min seq and max seq refer to the most + // recent log. + EXPECT_EQ(recovery.GetMinSeq(), 19); + EXPECT_EQ(recovery.GetMaxSeq(), 22); + } +} + +// Metadata lives in a separate file and should be fully preserved across log +// rotations caused by a checkpoint. 
+TEST_F(RaftRecoveryTest, MetadataPersistedAcrossCheckpoint) { + std::promise insert_done, ckpt_fired; + auto insert_done_future = insert_done.get_future(); + auto ckpt_fired_future = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_future.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.WriteMetadata(7, 2, 50, 3); + + for (int i = 1; i <= 5; i++) { + AddTestEntry(recovery, i, i); + } + insert_done.set_value(true); + ckpt_fired_future.get(); + + for (int i = 6; i <= 8; i++) { + AddTestEntry(recovery, i, i); + } + } + + { + int64_t current_term = 0; + int32_t voted_for = 0; + uint64_t snapshot_last_index = 0; + uint64_t snapshot_last_term = 0; + + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &data) { + current_term = data.current_term; + voted_for = data.voted_for; + snapshot_last_index = data.snapshot_last_index; + snapshot_last_term = data.snapshot_last_term; + }, + [&](std::unique_ptr) {}, nullptr); + + EXPECT_EQ(current_term, 7); + EXPECT_EQ(voted_for, 2); + EXPECT_EQ(snapshot_last_index, 50); + EXPECT_EQ(snapshot_last_term, 3); + } +} + +// When Storage::Flush() fails, FinishFile() bails out early and the log file +// must NOT be rotated — only one file should remain on disk. 
+TEST_F(RaftRecoveryTest, CheckpointNotFinalizedWhenStorageFlushFails) { + MockStorage storage; + EXPECT_CALL(storage, Flush).WillRepeatedly(Return(false)); + + std::promise insert_done, ckpt_fired; + auto insert_done_future = insert_done.get_future(); + auto ckpt_fired_future = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_future.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, &storage, nullptr); + + for (int i = 1; i <= 5; i++) { + AddTestEntry(recovery, i, i); + } + insert_done.set_value(true); + ckpt_fired_future.get(); + + for (int i = 6; i <= 8; i++) { + AddTestEntry(recovery, i, i); + } + } + + // The file should never have been renamed; only one .log file exists. + std::vector log_list = Listlogs(log_path); + // 1 log file and one metadata file + EXPECT_EQ(log_list.size(), 2); +} + +ResConfigData GetConfigDataNoRecovery(int buf_size = 10) { + ResConfigData data; + data.set_recovery_enabled(false); + data.set_recovery_path(log_path); + data.set_recovery_buffer_size(buf_size); + data.set_recovery_ckpt_time_s(1); + return data; +} + +// When recovery_enabled=false, all write operations are no-ops and the WAL +// directory is never created on disk. +TEST_F(RaftRecoveryTest, RecoveryDisabledNoOpsAndCreatesNoDirectory) { + ResDBConfig config(GetConfigDataNoRecovery(1024), ReplicaInfo(), KeyInfo(), + CertificateInfo()); + + const std::string log_dir = + std::filesystem::path(log_path).parent_path().string(); + + // Precondition: directory does not exist (fixture constructor removes it). + ASSERT_FALSE(std::filesystem::exists(log_dir)); + + { + RaftRecovery recovery(config, &checkpoint_, nullptr, nullptr); + + // All of these must be silent no-ops.
+ for (int i = 1; i <= 5; ++i) { + AddTestEntry(recovery, i, i); + } + + recovery.WriteMetadata(7, 2, 50, 3); + + TruncationRecord trunc; + trunc.set_truncate_from_index(3); + trunc.set_truncate_from_term(2); + recovery.TruncateLog(trunc); + + // ReadLogs must also be a no-op and invoke neither callback. + bool metadata_cb_called = false; + bool record_cb_called = false; + recovery.ReadLogs( + [&](const RaftMetadata &) { metadata_cb_called = true; }, + [&](std::unique_ptr) { record_cb_called = true; }, nullptr); + + EXPECT_FALSE(metadata_cb_called); + EXPECT_FALSE(record_cb_called); + } + + // The WAL directory must never have been created. + EXPECT_FALSE(std::filesystem::exists(log_dir)) + << "WAL directory was created even though recovery is disabled"; +} + +// When recovery is disabled, ReadMetadata returns default metadata values. +TEST_F(RaftRecoveryTest, RecoveryDisabledReadMetadataReturnsDefaults) { + ResDBConfig config(GetConfigDataNoRecovery(1024), ReplicaInfo(), KeyInfo(), + CertificateInfo()); + + RaftRecovery recovery(config, &checkpoint_, nullptr, nullptr); + + RaftMetadata meta = recovery.ReadMetadata(); + EXPECT_EQ(meta.current_term, 0); + EXPECT_EQ(meta.voted_for, -1); + EXPECT_EQ(meta.snapshot_last_index, 0u); + EXPECT_EQ(meta.snapshot_last_term, 0u); +} + +// Truncation record seq == checkpoint value. +// +// Layout written to WAL: +// seq 1 – entry (term 1) +// seq 2 – entry (term 2) +// seq 3 – entry (term 3) +// seq 4 – entry (term 4) +// truncation with truncate_from_index=3 → stored at seq = 3-1 = 2 +// seq 3 – entry (term 13) +// seq 4 – entry (term 14) +// +// The checkpoint fires at seq=2, directly before the truncation. +// +// Expected replay (five records): the original entries at seq 3 and 4, the +// truncation record, then the replacement entries at seq 3 and 4.
+TEST_F(RaftRecoveryTest, TruncationAtCheckpointBoundary) { + std::promise insert_done, ckpt_fired; + auto insert_done_f = insert_done.get_future(); + auto ckpt_fired_f = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_f.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + // Checkpoint at 2 — the same seq as the truncation record. + return 2; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + // Write entries 1–4 at seq 1–4. + for (int i = 1; i <= 4; ++i) { + AddTestEntry(recovery, i, i); + } + + // Truncate from index 3 → stored at seq = 2. + TruncationRecord trunc; + trunc.set_truncate_from_index(3); + trunc.set_truncate_from_term(2); + recovery.TruncateLog(trunc); + + // Write two replacement entries at seq 3–4 (new leader's branch). + for (int i = 3; i <= 4; ++i) { + AddTestEntry(recovery, 10 + i, i); + } + + insert_done.set_value(true); + ckpt_fired_f.get(); + // File is now sealed at ckpt=2. The active window starts fresh. 
+ } + + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + ASSERT_EQ(list.size(), 5u); + EXPECT_EQ(list[0].payload_case(), WALRecord::kEntry); + EXPECT_EQ(list[1].payload_case(), WALRecord::kEntry); + EXPECT_EQ(list[2].payload_case(), WALRecord::kTruncation); + EXPECT_EQ(list[3].payload_case(), WALRecord::kEntry); + EXPECT_EQ(list[4].payload_case(), WALRecord::kEntry); + + Request req3, req4, req3again, req4again; + req3.ParseFromString(list[0].entry().command()); + req4.ParseFromString(list[1].entry().command()); + EXPECT_EQ(req3.seq(), 3); + EXPECT_EQ(req4.seq(), 4); + EXPECT_EQ(list[0].entry().term(), 3); + EXPECT_EQ(list[1].entry().term(), 4); + + EXPECT_EQ(list[2].truncation().truncate_from_index(), 3); + + req3again.ParseFromString(list[3].entry().command()); + req4again.ParseFromString(list[4].entry().command()); + EXPECT_EQ(req3again.seq(), 3); + EXPECT_EQ(req4again.seq(), 4); + EXPECT_EQ(list[3].entry().term(), 13); + EXPECT_EQ(list[4].entry().term(), 14); + } +} + +// Truncation record seq BELOW checkpoint value: dropped. +// +// Similar layout, but the replacement entries run through seq 8 and the +// checkpoint fires at stable_seq=5 (above the truncation's seq=2). All +// records with seq ≤ 5 are behind the checkpoint, so only the entries at +// seq 6, 7 and 8 are replayed. In this variant we check that no truncation +// record bleeds through in the surviving window.
+TEST_F(RaftRecoveryTest, TruncationBelowCheckpointIsDropped) { + std::promise insert_done, ckpt_fired; + auto insert_done_f = insert_done.get_future(); + auto ckpt_fired_f = ckpt_fired.get_future(); + + int call_count = 0; + EXPECT_CALL(checkpoint_, GetStableCheckpoint()) + .WillRepeatedly(Invoke([&]() -> uint64_t { + ++call_count; + if (call_count == 1) + insert_done_f.get(); + else if (call_count == 2) + ckpt_fired.set_value(true); + return 5; + })); + + { + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + + for (int i = 1; i <= 4; ++i) { + AddTestEntry(recovery, i, i); + } + + TruncationRecord trunc; + trunc.set_truncate_from_index(3); + trunc.set_truncate_from_term(2); + recovery.TruncateLog(trunc); + + for (int i = 3; i <= 8; ++i) { + AddTestEntry(recovery, 10 + i, i); + } + + insert_done.set_value(true); + ckpt_fired_f.get(); + } + + { + std::vector list; + RaftRecovery recovery(config_, &checkpoint_, nullptr, nullptr); + recovery.ReadLogs( + [&](const RaftMetadata &) {}, + [&](std::unique_ptr record) { list.push_back(*record); }, + nullptr); + + // Entries at seq 6, 7, 8 survive (strictly > ckpt=5). + // The truncation at seq=2 is entirely behind the checkpoint and must not + // appear. + for (const auto &r : list) { + EXPECT_EQ(r.payload_case(), WALRecord::kEntry) + << "Truncation record below checkpoint leaked into replay"; + } + ASSERT_EQ(list.size(), 3u); + for (size_t i = 0; i < list.size(); ++i) { + Request req; + req.ParseFromString(list[i].entry().command()); + EXPECT_EQ(req.seq(), (int)(i + 6)); + } + } +} + +// TODO: Create tests that corrupt recovery files to test our handling of them. 
+ +} // namespace raft +} // namespace resdb diff --git a/platform/consensus/recovery/recovery.h b/platform/consensus/recovery/recovery.h index 426bffd275..1517f6dd8a 100644 --- a/platform/consensus/recovery/recovery.h +++ b/platform/consensus/recovery/recovery.h @@ -19,85 +19,94 @@ #pragma once +#include +#include +#include +#include +#include +#include + +#include +#include +#include #include #include "chain/storage/storage.h" +#include "common/utils/utils.h" #include "platform/config/resdb_config.h" #include "platform/consensus/checkpoint/checkpoint.h" -#include "platform/consensus/execution/system_info.h" #include "platform/networkstrate/server_comm.h" #include "platform/proto/resdb.pb.h" -#include "platform/proto/system_info_data.pb.h" namespace resdb { -class Recovery { +template +class RecoveryBase { public: - Recovery(const ResDBConfig& config, CheckPoint* checkpoint, - SystemInfo* system_info, Storage* storage); - virtual ~Recovery(); + RecoveryBase(const ResDBConfig& config, CheckPoint* checkpoint, + Storage* storage, + std::function on_checkpoint = nullptr); + ~RecoveryBase(); - void Init(); - - virtual void AddRequest(const Context* context, const Request* request); - void ReadLogs(std::function system_callback, - std::function context, - std::unique_ptr request)> - call_back, - std::function start_point); + void ReadLogs( + std::function system_callback, + TCallback call_back, std::function start_point); int64_t GetMaxSeq(); int64_t GetMinSeq(); - int GetData(const RecoveryRequest& request, RecoveryResponse& response); + protected: + std::vector> GetSortedRecoveryFiles( + uint64_t need_min_seq, uint64_t need_max_seq); - std::map, - std::unique_ptr>>> - GetDataFromRecoveryFiles(uint64_t need_min_seq, uint64_t need_max_seq); + std::vector ParseRawData(const std::string& data); private: - struct RecoveryData { - std::unique_ptr context; - std::unique_ptr request; - }; + auto ParseData(const std::string& data); - void WriteLog(const Context* context, const 
Request* request); - void AppendData(const std::string& data); - std::vector> ParseData(const std::string& data); - std::vector ParseRawData(const std::string& data); - void Flush(); void MayFlush(); void Write(const char* data, size_t len); - bool Read(int fd, size_t len, char* data); - + std::string GenerateFile(int64_t seq, int64_t min_seq, int64_t max_seq); - void GetLastFile(); - void WriteSystemInfo(); - void OpenFile(const std::string& path); void FinishFile(int64_t seq); - void SwitchFile(const std::string& path); + void InsertCache(const Context& context, const Request& request); + + protected: + void GetLastFile(); void UpdateStableCheckPoint(); - std::pair>, int64_t> - GetRecoveryFiles(int64_t ckpt); + void Flush(); + + void AppendData(const std::string& data); + bool Read(int fd, size_t len, char* data); + std::pair>, int64_t> GetRecoveryFiles(int64_t ckpt); + + void SwitchFile(const std::string& path, TCallback call_back); + void OpenFile(const std::string& path); + void ReadLogsFromFiles( const std::string& path, int64_t ckpt, int file_idx, - std::function system_callback, - std::function context, - std::unique_ptr request)> - call_back); - - void InsertCache(const Context& context, const Request& request); + std::function system_callback, + TCallback call_back); - protected: + std::string file_path_; ResDBConfig config_; + // Derived class must implement these + auto ParseDataListItem(std::vector& data_list); + + template + void PerformCallback(RequestList& request_list, TCallback call_back); + + void WriteSystemInfo(); + CheckPoint* checkpoint_; std::thread ckpt_thread_; bool recovery_enabled_ = false; std::string buffer_; - std::string file_path_, base_file_path_; + + std::string base_file_path_; size_t buffer_size_ = 0; int fd_; std::mutex mutex_, data_mutex_; @@ -107,8 +116,10 @@ class Recovery { std::mutex ckpt_mutex_; std::atomic stop_; int recovery_ckpt_time_s_; - SystemInfo* system_info_; Storage* storage_; + std::function 
on_checkpoint_callback_; }; +#include "platform/consensus/recovery/recovery_impl.h" + } // namespace resdb diff --git a/platform/consensus/recovery/recovery.cpp b/platform/consensus/recovery/recovery_impl.h similarity index 59% rename from platform/consensus/recovery/recovery.cpp rename to platform/consensus/recovery/recovery_impl.h index 9e37f6eb70..463dc905e5 100644 --- a/platform/consensus/recovery/recovery.cpp +++ b/platform/consensus/recovery/recovery_impl.h @@ -17,29 +17,14 @@ * under the License. */ -#include "platform/consensus/recovery/recovery.h" - -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "common/utils/utils.h" - -namespace resdb { - -Recovery::Recovery(const ResDBConfig& config, CheckPoint* checkpoint, - SystemInfo* system_info, Storage* storage) +template +RecoveryBase::RecoveryBase( + const ResDBConfig& config, CheckPoint* checkpoint, Storage* storage, + std::function on_checkpoint) : config_(config), checkpoint_(checkpoint), - system_info_(system_info), - storage_(storage) { + storage_(storage), + on_checkpoint_callback_(on_checkpoint) { recovery_enabled_ = config_.GetConfigData().recovery_enabled(); file_path_ = config_.GetConfigData().recovery_path(); if (file_path_.empty()) { @@ -80,19 +65,10 @@ Recovery::Recovery(const ResDBConfig& config, CheckPoint* checkpoint, fd_ = -1; stop_ = false; - Init(); -} - -void Recovery::Init() { - LOG(ERROR) << " init"; - GetLastFile(); - SwitchFile(file_path_); - LOG(ERROR) << " init done"; - - ckpt_thread_ = std::thread(&Recovery::UpdateStableCheckPoint, this); } -Recovery::~Recovery() { +template +RecoveryBase::~RecoveryBase() { if (recovery_enabled_ == false) { return; } @@ -104,11 +80,19 @@ Recovery::~Recovery() { } } -int64_t Recovery::GetMaxSeq() { return max_seq_; } +template +int64_t RecoveryBase::GetMaxSeq() { + return max_seq_; +} -int64_t Recovery::GetMinSeq() { return min_seq_; } +template +int64_t RecoveryBase::GetMinSeq() { + return 
min_seq_; +} -void Recovery::UpdateStableCheckPoint() { +template +void RecoveryBase::UpdateStableCheckPoint() { if (checkpoint_ == nullptr) { return; } @@ -124,7 +108,8 @@ void Recovery::UpdateStableCheckPoint() { } } -void Recovery::GetLastFile() { +template +void RecoveryBase::GetLastFile() { std::string dir = std::filesystem::path(file_path_).parent_path(); last_ckpt_ = -1; uint64_t m_time_s = 0; @@ -162,8 +147,9 @@ void Recovery::GetLastFile() { } } -std::string Recovery::GenerateFile(int64_t seq, int64_t min_seq, - int64_t max_seq) { +template +std::string RecoveryBase::GenerateFile( + int64_t seq, int64_t min_seq, int64_t max_seq) { std::string dir = std::filesystem::path(file_path_).parent_path(); std::string file_name = std::filesystem::path(base_file_path_).stem(); int64_t time = GetCurrentTime(); @@ -175,133 +161,54 @@ std::string Recovery::GenerateFile(int64_t seq, int64_t min_seq, return dir + "/" + file_name + "." + ext; } -void Recovery::FinishFile(int64_t seq) { - std::unique_lock lk(mutex_); - Flush(); - if (storage_) { - if (!storage_->Flush()) { - return; +template +void RecoveryBase::FinishFile( + int64_t seq) { + { + std::unique_lock lk(mutex_); + Flush(); + if (storage_) { + if (!storage_->Flush(true)) { + return; + } } - } - std::string new_file_path = GenerateFile(seq, min_seq_, max_seq_); - close(fd_); - - min_seq_ = -1; - max_seq_ = -1; - - std::rename(file_path_.c_str(), new_file_path.c_str()); - - LOG(INFO) << "rename:" << file_path_ << " to:" << new_file_path; - std::string next_file_path = GenerateFile(seq, -1, -1); - file_path_ = next_file_path; - - OpenFile(file_path_); -} - -void Recovery::SwitchFile(const std::string& file_path) { - std::unique_lock lk(mutex_); - - min_seq_ = -1; - max_seq_ = -1; - - ReadLogsFromFiles( - file_path, 0, 0, [&](const SystemInfoData& data) {}, - [&](std::unique_ptr context, std::unique_ptr request) { - min_seq_ == -1 - ? 
min_seq_ = request->seq() - : std::min(min_seq_, static_cast(request->seq())); - max_seq_ = std::max(max_seq_, static_cast(request->seq())); - }); - - OpenFile(file_path); - LOG(INFO) << "switch to file:" << file_path << " seq:" - << "[" << min_seq_ << "," << max_seq_ << "]"; -} - -void Recovery::OpenFile(const std::string& path) { - if (fd_ >= 0) { + std::string new_file_path = GenerateFile(seq, min_seq_, max_seq_); close(fd_); - } - fd_ = open(path.c_str(), O_CREAT | O_WRONLY, 0666); - if (fd_ < 0) { - LOG(ERROR) << "open file fail:" << path << " error:" << strerror(errno); - } - int pos = lseek(fd_, 0, SEEK_END); - LOG(INFO) << "file path:" << path << " len:" << pos << " fd:" << fd_; + min_seq_ = -1; + max_seq_ = -1; - if (pos == 0) { - WriteSystemInfo(); - } - - lseek(fd_, 0, SEEK_END); - LOG(ERROR) << "open file:" << path << " pos:" << lseek(fd_, 0, SEEK_CUR) - << " fd:" << fd_; - assert(fd_ >= 0); -} + std::rename(file_path_.c_str(), new_file_path.c_str()); -void Recovery::WriteSystemInfo() { - int view = system_info_->GetCurrentView(); - int primary_id = system_info_->GetPrimaryId(); - LOG(ERROR) << "write system info:" << primary_id << " view:" << view; - SystemInfoData data; - data.set_view(view); - data.set_primary_id(primary_id); + std::string dir_path = + std::filesystem::path(file_path_).parent_path().string(); + int dir_fd = open(dir_path.c_str(), O_RDONLY); + fsync(dir_fd); + close(dir_fd); - std::string data_str; - data.SerializeToString(&data_str); + LOG(INFO) << "rename:" << file_path_ << " to:" << new_file_path; + std::string next_file_path = GenerateFile(seq, -1, -1); + file_path_ = next_file_path; - AppendData(data_str); - Flush(); -} - -void Recovery::AddRequest(const Context* context, const Request* request) { - if (recovery_enabled_ == false) { - return; - } - switch (request->type()) { - case Request::TYPE_PRE_PREPARE: - case Request::TYPE_PREPARE: - case Request::TYPE_COMMIT: - case Request::TYPE_NEWVIEW: - return WriteLog(context, 
request); - default: - break; - } -} - -void Recovery::WriteLog(const Context* context, const Request* request) { - std::string data; - if (request) { - request->SerializeToString(&data); + OpenFile(file_path_); } - std::string sig; - if (context) { - context->signature.SerializeToString(&sig); + if (on_checkpoint_callback_) { + on_checkpoint_callback_(seq); } - - std::unique_lock lk(mutex_); - min_seq_ = min_seq_ == -1 - ? request->seq() - : std::min(min_seq_, static_cast(request->seq())); - max_seq_ = std::max(max_seq_, static_cast(request->seq())); - AppendData(data); - AppendData(sig); - - Flush(); } -void Recovery::AppendData(const std::string& data) { +template +void RecoveryBase::AppendData( + const std::string& data) { size_t len = data.size(); buffer_.append(reinterpret_cast(&len), sizeof(len)); buffer_.append(data); } -std::vector> Recovery::ParseData( +template +auto RecoveryBase::ParseData( const std::string& data) { - std::vector> request_list; - std::vector data_list; int pos = 0; while (pos < data.size()) { @@ -314,28 +221,13 @@ std::vector> Recovery::ParseData( data_list.push_back(item); } - for (size_t i = 0; i < data_list.size(); i += 2) { - std::unique_ptr recovery_data = - std::make_unique(); - recovery_data->request = std::make_unique(); - recovery_data->context = std::make_unique(); - - if (!recovery_data->request->ParseFromString(data_list[i])) { - LOG(ERROR) << "Parse from data fail"; - break; - } - - if (!recovery_data->context->signature.ParseFromString(data_list[i + 1])) { - LOG(ERROR) << "Parse from data fail"; - break; - } - - request_list.push_back(std::move(recovery_data)); - } - return request_list; + return static_cast(this)->ParseDataListItem(data_list); } -std::vector Recovery::ParseRawData(const std::string& data) { +template +std::vector +RecoveryBase::ParseRawData( + const std::string& data) { std::vector data_list; int pos = 0; while (pos < data.size()) { @@ -350,13 +242,15 @@ std::vector Recovery::ParseRawData(const 
std::string& data) { return data_list; } -void Recovery::MayFlush() { +template +void RecoveryBase::MayFlush() { if (buffer_.size() > buffer_size_) { Flush(); } } -void Recovery::Flush() { +template +void RecoveryBase::Flush() { size_t len = buffer_.size(); if (len == 0) { return; @@ -368,16 +262,22 @@ void Recovery::Flush() { fsync(fd_); } -void Recovery::Write(const char* data, size_t len) { +template +void RecoveryBase::Write(const char* data, + size_t len) { int pos = 0; while (len > 0) { int write_len = write(fd_, data + pos, len); + if (write_len <= 0) break; len -= write_len; pos += write_len; } } -bool Recovery::Read(int fd, size_t len, char* data) { +template +bool RecoveryBase::Read(int fd, + size_t len, + char* data) { int pos = 0; while (len > 0) { int read_len = read(fd, data + pos, len); @@ -390,8 +290,10 @@ bool Recovery::Read(int fd, size_t len, char* data) { return true; } +template std::pair>, int64_t> -Recovery::GetRecoveryFiles(int64_t ckpt) { +RecoveryBase::GetRecoveryFiles( + int64_t ckpt) { std::string dir = std::filesystem::path(file_path_).parent_path(); int64_t last_ckpt = 0; for (const auto& entry : std::filesystem::directory_iterator(dir)) { @@ -437,24 +339,72 @@ Recovery::GetRecoveryFiles(int64_t ckpt) { } sort(e_list.begin(), e_list.end()); - list.push_back(e_list.back()); + if (!e_list.empty()) { + list.push_back(e_list.back()); + } sort(list.begin(), list.end()); return std::make_pair(list, last_ckpt); } -void Recovery::ReadLogs( - std::function system_callback, - std::function context, - std::unique_ptr request)> - call_back, - std::function set_start_point) { +template +std::vector> +RecoveryBase::GetSortedRecoveryFiles( + uint64_t need_min_seq, uint64_t need_max_seq) { + std::string dir = std::filesystem::path(file_path_).parent_path(); + + std::vector> list; + std::vector> e_list; + + for (const auto& entry : std::filesystem::directory_iterator(dir)) { + std::string dir = std::filesystem::path(entry.path()).parent_path(); + 
std::string file_name = std::filesystem::path(entry.path()).stem(); + std::string ext = std::filesystem::path(entry.path()).extension(); + if (ext != ".log") continue; + int pos = file_name.rfind("_"); + + int max_seq_pos = file_name.rfind("_", pos - 1); + int64_t max_seq = + std::stoll(file_name.substr(max_seq_pos + 1, pos - max_seq_pos - 1)); + + int min_seq_pos = file_name.rfind("_", max_seq_pos - 1); + int64_t min_seq = std::stoll( + file_name.substr(min_seq_pos + 1, max_seq_pos - min_seq_pos - 1)); + + int time_pos = file_name.rfind("_", min_seq_pos - 1); + int64_t time = + std::stoll(file_name.substr(time_pos + 1, min_seq_pos - time_pos - 1)); + + // LOG(ERROR)<<" min seq:"< need_max_seq) { + continue; + } + // LOG(ERROR)<<" get min seq:"< +void RecoveryBase::ReadLogs( + std::function system_callback, + TCallback call_back, std::function set_start_point) { if (recovery_enabled_ == false) { return; } int64_t storage_ckpt = 0; - if(storage_) { + if (storage_) { storage_ckpt = storage_->GetLastCheckpoint(); } std::unique_lock lk(mutex_); @@ -470,12 +420,49 @@ void Recovery::ReadLogs( } } -void Recovery::ReadLogsFromFiles( +template +void RecoveryBase::SwitchFile( + const std::string& file_path, TCallback call_back) { + std::unique_lock lk(mutex_); + + min_seq_ = -1; + max_seq_ = -1; + ReadLogsFromFiles( + file_path, 0, 0, [&](const TSystemInfoData& data) {}, call_back); + OpenFile(file_path); + LOG(INFO) << "switch to file:" << file_path << " seq:" + << "[" << min_seq_ << "," << max_seq_ << "]"; +} + +template +void RecoveryBase::OpenFile( + const std::string& path) { + if (fd_ >= 0) { + close(fd_); + } + fd_ = open(path.c_str(), O_CREAT | O_WRONLY, 0666); + if (fd_ < 0) { + LOG(ERROR) << "open file fail:" << path << " error:" << strerror(errno); + } + + int pos = lseek(fd_, 0, SEEK_END); + LOG(INFO) << "file path:" << path << " len:" << pos << " fd:" << fd_; + + if (pos == 0) { + static_cast(this)->WriteSystemInfo(); + } + + lseek(fd_, 0, SEEK_END); + 
LOG(ERROR) << "open file:" << path << " pos:" << lseek(fd_, 0, SEEK_CUR) + << " fd:" << fd_; + assert(fd_ >= 0); +} + +template +void RecoveryBase::ReadLogsFromFiles( const std::string& path, int64_t ckpt, int file_idx, - std::function system_callback, - std::function context, - std::unique_ptr request)> - call_back) { + std::function system_callback, + TCallback call_back) { int fd = open(path.c_str(), O_CREAT | O_RDONLY, 0666); if (fd < 0) { LOG(ERROR) << " open file fail:" << path; @@ -484,28 +471,9 @@ void Recovery::ReadLogsFromFiles( assert(fd >= 0); size_t data_len = 0; - Read(fd, sizeof(data_len), reinterpret_cast(&data_len)); - { - std::string data; - char* buf = new char[data_len]; - if (!Read(fd, data_len, buf)) { - LOG(ERROR) << "Read system info fail"; - return; - } - data = std::string(buf, data_len); - delete buf; - std::vector data_list = ParseRawData(data); - - SystemInfoData info; - if (data_list.empty() || !info.ParseFromString(data_list[0])) { - LOG(ERROR) << "parse info fail:" << data.size(); - return; - } - LOG(ERROR) << "read system info:" << info.DebugString(); - system_callback(info); - } + static_cast(this)->HandleSystemInfo(fd, system_callback); - std::vector> request_list; + decltype(ParseData(std::string{})) request_list; while (Read(fd, sizeof(data_len), reinterpret_cast(&data_len))) { std::string data; @@ -517,9 +485,8 @@ void Recovery::ReadLogsFromFiles( data = std::string(buf, data_len); delete buf; - std::vector> list = ParseData(data); + auto list = ParseData(data); if (list.size() == 0) { - request_list.clear(); break; } for (auto& l : list) { @@ -527,106 +494,20 @@ void Recovery::ReadLogsFromFiles( } } if (request_list.size() == 0) { - ftruncate(fd, 0); - } - uint64_t max_seq = 0; - for (std::unique_ptr& recovery_data : request_list) { - // LOG(ERROR)<<" ckpt :"<request->seq()<<" - // type:"<request->type(); - if (ckpt < recovery_data->request->seq() || - recovery_data->request->type() == Request::TYPE_NEWVIEW) { - 
recovery_data->request->set_is_recovery(true); - max_seq = recovery_data->request->seq(); - call_back(std::move(recovery_data->context), - std::move(recovery_data->request)); + LOG(ERROR) << " Request list is empty"; + close(fd); + fd = open(path.c_str(), O_RDWR); + if (fd < 0) { + LOG(ERROR) << " open file as O_RDWR to truncate fail:" << path; } - } - - LOG(ERROR) << "read log from files:" << path << " done" - << " recovery max seq:" << max_seq; - - close(fd); -} - -int Recovery::GetData(const RecoveryRequest& request, - RecoveryResponse& response) { - auto res = GetDataFromRecoveryFiles(request.min_seq(), request.max_seq()); - - for (const auto& it : res) { - for (const auto& req : it.second) { - *response.add_signature() = req.first->signature; - *response.add_request() = *req.second; + if (ftruncate(fd, 0) != 0) { + LOG(ERROR) << " Failed to truncate file"; } - } - return 0; -} - -std::map< - uint64_t, - std::vector, std::unique_ptr>>> -Recovery::GetDataFromRecoveryFiles(uint64_t need_min_seq, - uint64_t need_max_seq) { - std::string dir = std::filesystem::path(file_path_).parent_path(); - - std::vector> list; - std::vector> e_list; - - for (const auto& entry : std::filesystem::directory_iterator(dir)) { - std::string dir = std::filesystem::path(entry.path()).parent_path(); - std::string file_name = std::filesystem::path(entry.path()).stem(); - std::string ext = std::filesystem::path(entry.path()).extension(); - if (ext != ".log") continue; - int pos = file_name.rfind("_"); - - int max_seq_pos = file_name.rfind("_", pos - 1); - int64_t max_seq = - std::stoll(file_name.substr(max_seq_pos + 1, pos - max_seq_pos - 1)); - - int min_seq_pos = file_name.rfind("_", max_seq_pos - 1); - int64_t min_seq = std::stoll( - file_name.substr(min_seq_pos + 1, max_seq_pos - min_seq_pos - 1)); - - int time_pos = file_name.rfind("_", min_seq_pos - 1); - int64_t time = - std::stoll(file_name.substr(time_pos + 1, min_seq_pos - time_pos - 1)); - - // LOG(ERROR)<<" min seq:"< 
need_max_seq) { - continue; - } - // LOG(ERROR)<<" get min seq:"<(this)->PerformCallback(request_list, call_back, ckpt); - std::map, - std::unique_ptr>>> - res; - for (const auto& path : list) { - ReadLogsFromFiles( - path.second, need_min_seq - 1, 0, [&](const SystemInfoData& data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { - // LOG(ERROR) << "check get data from recovery file seq:" - // << request->seq(); - if (request->seq() >= need_min_seq && - request->seq() <= need_max_seq) { - LOG(ERROR) << "get data from recovery file seq:" << request->seq(); - res[request->seq()].push_back( - std::make_pair(std::move(context), std::move(request))); - } - }); - } - - return res; + LOG(ERROR) << "read log from files:" << path << " done"; + close(fd); } - -} // namespace resdb diff --git a/platform/consensus/recovery/recovery_test.cpp b/platform/consensus/recovery/recovery_test.cpp index f6aeba4734..9317589ef2 100644 --- a/platform/consensus/recovery/recovery_test.cpp +++ b/platform/consensus/recovery/recovery_test.cpp @@ -17,8 +17,6 @@ * under the License. 
*/ -#include "platform/consensus/recovery/recovery.h" - #include #include #include @@ -30,6 +28,7 @@ #include "common/test/test_macros.h" #include "platform/consensus/checkpoint/mock_checkpoint.h" #include "platform/consensus/ordering/common/transaction_utils.h" +#include "platform/consensus/recovery/pbft_recovery.h" namespace resdb { namespace { @@ -87,7 +86,7 @@ TEST_F(RecoveryTest, ReadLog) { }; { - Recovery recovery(config_, &checkpoint_, &system_info_, nullptr); + PBFTRecovery recovery(config_, &checkpoint_, &system_info_, nullptr); for (int t : types) { std::unique_ptr request = @@ -98,12 +97,16 @@ TEST_F(RecoveryTest, ReadLog) { } { std::vector list; - Recovery recovery(config_, &checkpoint_, &system_info_, nullptr); - recovery.ReadLogs( - [&](const SystemInfoData &data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { list.push_back(*request); }, - nullptr); + PBFTRecovery recovery(config_, &checkpoint_, &system_info_, nullptr); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + + recovery.ReadLogs([&](const SystemInfoData &data) {}, call_back, nullptr); EXPECT_EQ(list.size(), expected_types.size()); @@ -127,7 +130,7 @@ TEST_F(RecoveryTest, ReadLog_FlushOnce) { }; { - Recovery recovery(config, &checkpoint_, &system_info_, nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, nullptr); for (int t : types) { std::unique_ptr request = @@ -138,13 +141,14 @@ TEST_F(RecoveryTest, ReadLog_FlushOnce) { } { std::vector list; - Recovery recovery(config, &checkpoint_, &system_info_, nullptr); - recovery.ReadLogs([&](const SystemInfoData &data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - }, - nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, nullptr); + + std::function, std::unique_ptr)> + call_back = + [&](std::unique_ptr 
context, + std::unique_ptr request) { list.push_back(*request); }; + + recovery.ReadLogs([&](const SystemInfoData &data) {}, call_back, nullptr); EXPECT_EQ(list.size(), expected_types.size()); @@ -179,7 +183,7 @@ TEST_F(RecoveryTest, CheckPoint) { })); { - Recovery recovery(config, &checkpoint_, &system_info_, nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, nullptr); for (int i = 1; i < 10; ++i) { for (int t : types) { @@ -204,14 +208,16 @@ TEST_F(RecoveryTest, CheckPoint) { EXPECT_EQ(log_list.size(), 2); { std::vector list; - Recovery recovery(config, &checkpoint_, &system_info_, nullptr); - recovery.ReadLogs([&](const SystemInfoData &data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - // LOG(ERROR)<<"call back:"<seq(); - }, - nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, nullptr); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + + recovery.ReadLogs([&](const SystemInfoData &data) {}, call_back, nullptr); EXPECT_EQ(list.size(), types.size() * 14); @@ -257,7 +263,7 @@ TEST_F(RecoveryTest, CheckPoint2) { })); { - Recovery recovery(config, &checkpoint_, &system_info_, &storage); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); for (int i = 1; i < 10; ++i) { for (int t : types) { @@ -282,14 +288,16 @@ TEST_F(RecoveryTest, CheckPoint2) { EXPECT_EQ(log_list.size(), 2); { std::vector list; - Recovery recovery(config, &checkpoint_, &system_info_, &storage); - recovery.ReadLogs([&](const SystemInfoData &data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - // LOG(ERROR)<<"call back:"<seq(); - }, - nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + 
std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + + recovery.ReadLogs([&](const SystemInfoData &data) {}, call_back, nullptr); EXPECT_EQ(list.size(), types.size() * 14); @@ -320,14 +328,16 @@ TEST_F(RecoveryTest, CheckPoint2) { { std::vector list; - Recovery recovery(config, &checkpoint_, &system_info_, &storage); - recovery.ReadLogs([&](const SystemInfoData &data) {}, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - // LOG(ERROR)<<"call back:"<seq(); - }, - nullptr); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + + recovery.ReadLogs([&](const SystemInfoData &data) {}, call_back, nullptr); EXPECT_EQ(list.size(), types.size() * 9); @@ -375,7 +385,7 @@ TEST_F(RecoveryTest, SystemInfo) { })); { - Recovery recovery(config, &checkpoint_, &system_info_, &storage); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); system_info_.SetCurrentView(2); system_info_.SetPrimary(2); @@ -403,14 +413,17 @@ TEST_F(RecoveryTest, SystemInfo) { { std::vector list; SystemInfoData data; - Recovery recovery(config, &checkpoint_, &system_info_, &storage); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + recovery.ReadLogs([&](const SystemInfoData &r_data) { data = r_data; }, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - // LOG(ERROR)<<"call back:"<seq(); - }, - nullptr); + call_back, nullptr); EXPECT_EQ(list.size(), types.size() * 14); @@ -442,14 +455,17 @@ TEST_F(RecoveryTest, SystemInfo) { { 
std::vector list; SystemInfoData data; - Recovery recovery(config, &checkpoint_, &system_info_, &storage); + PBFTRecovery recovery(config, &checkpoint_, &system_info_, &storage); + + std::function, std::unique_ptr)> + call_back = [&](std::unique_ptr context, + std::unique_ptr request) { + list.push_back(*request); + // LOG(ERROR) << "call back:" << request->seq(); + }; + recovery.ReadLogs([&](const SystemInfoData &r_data) { data = r_data; }, - [&](std::unique_ptr context, - std::unique_ptr request) { - list.push_back(*request); - // LOG(ERROR)<<"call back:"<seq(); - }, - nullptr); + call_back, nullptr); EXPECT_EQ(data.view(), 2); EXPECT_EQ(data.primary_id(), 2); diff --git a/raft_performance.sh b/raft_performance.sh new file mode 100755 index 0000000000..c688e75e3d --- /dev/null +++ b/raft_performance.sh @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Move into the deploy directory +cd ~/incubator-resilientdb/scripts/deploy + +# Run the performance script with the config file +./performance_local/raft_performance.sh config/kv_performance_server_local.conf >out_raft.txt 2>&1 diff --git a/scripts/deploy/config/pbft.config b/scripts/deploy/config/pbft.config index 1de013abf3..e13d03df63 100644 --- a/scripts/deploy/config/pbft.config +++ b/scripts/deploy/config/pbft.config @@ -22,8 +22,8 @@ "enable_viewchange": true, "recovery_enabled": true, "max_client_complaint_num":10, - "max_process_txn": 2048, - "worker_num": 2, - "input_worker_num": 1, - "output_worker_num": 10 + "max_process_txn": 64, + "worker_num": 16, + "input_worker_num": 5, + "output_worker_num": 5 } diff --git a/scripts/deploy/config/raft.config b/scripts/deploy/config/raft.config new file mode 100644 index 0000000000..d24bb8493b --- /dev/null +++ b/scripts/deploy/config/raft.config @@ -0,0 +1,11 @@ +{ + "clientBatchNum": 100, + "enable_viewchange": true, + "recovery_enabled": true, + "not_need_signature": true, + "max_client_complaint_num":10, + "max_process_txn": 64, + "worker_num": 16, + 
"input_worker_num": 5, + "output_worker_num": 5 +} diff --git a/scripts/deploy/performance/calculate_result.py b/scripts/deploy/performance/calculate_result.py index f6892d2685..5852c3d472 100644 --- a/scripts/deploy/performance/calculate_result.py +++ b/scripts/deploy/performance/calculate_result.py @@ -53,7 +53,11 @@ def cal_lat(lat): lat_sum.append(v) print("max latency:",lat_max) - print("average latency:",sum(lat_sum)/len(lat_sum)) + if not len(lat_sum): + average_latency = 0 + else: + average_latency = sum(lat_sum)/len(lat_sum) + print("average latency:", average_latency) if __name__ == '__main__': files = sys.argv[1:] diff --git a/scripts/deploy/performance/raft_performance.sh b/scripts/deploy/performance/raft_performance.sh new file mode 100755 index 0000000000..9603df099f --- /dev/null +++ b/scripts/deploy/performance/raft_performance.sh @@ -0,0 +1,23 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +export server=//benchmark/protocols/raft:kv_server_performance +export TEMPLATE_PATH=$PWD/config/raft.config + +./performance/run_performance.sh $* diff --git a/scripts/deploy/performance/run_performance.sh b/scripts/deploy/performance/run_performance.sh index 7ef2b8a798..ca3f49a132 100755 --- a/scripts/deploy/performance/run_performance.sh +++ b/scripts/deploy/performance/run_performance.sh @@ -54,6 +54,6 @@ done python3 performance/calculate_result.py `ls result_*_log` > results.log -rm -rf result_*_log +#rm -rf result_*_log echo "save result to results.log" cat results.log diff --git a/scripts/deploy/performance_local/pbft_performance.sh b/scripts/deploy/performance_local/pbft_performance.sh index 003b9787f8..5b337fe318 100755 --- a/scripts/deploy/performance_local/pbft_performance.sh +++ b/scripts/deploy/performance_local/pbft_performance.sh @@ -17,8 +17,12 @@ # under the License. # -export server=//benchmark/protocols/pbft:kv_server_performance -#export TEMPLATE_PATH=$PWD/config/pbft.config +protocol=pbft +export server=//benchmark/protocols/$protocol:kv_server_performance +export service_tools=//benchmark/protocols/$protocol:kv_service_tools +export TEMPLATE_PATH=$PWD/config/$protocol.config export performance=true +#export TEMPLATE_PATH=$PWD/config/pbft.config + ./performance_local/run_performance.sh $* diff --git a/scripts/deploy/performance_local/poe_performance.sh b/scripts/deploy/performance_local/poe_performance.sh index fd23e077a1..bbb48846fe 100755 --- a/scripts/deploy/performance_local/poe_performance.sh +++ b/scripts/deploy/performance_local/poe_performance.sh @@ -11,13 +11,16 @@ # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied.
See the License for the # specific language governing permissions and limitations # under the License. # -export server=//benchmark/protocols/poe:kv_server_performance -export TEMPLATE_PATH=$PWD/config/poe.config +protocol=poe +export server=//benchmark/protocols/$protocol:kv_server_performance +export service_tools=//benchmark/protocols/pbft:kv_service_tools +export TEMPLATE_PATH=$PWD/config/$protocol.config +export performance=true ./performance_local/run_performance.sh $* diff --git a/scripts/deploy/performance_local/raft_performance.sh b/scripts/deploy/performance_local/raft_performance.sh new file mode 100755 index 0000000000..def1aa30ef --- /dev/null +++ b/scripts/deploy/performance_local/raft_performance.sh @@ -0,0 +1,25 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +protocol=raft +export server=//benchmark/protocols/$protocol:kv_server_performance +export service_tools=//benchmark/protocols/$protocol:kv_service_tools +export TEMPLATE_PATH=$PWD/config/$protocol.config +export performance=true + +./performance_local/run_performance.sh $* diff --git a/scripts/deploy/performance_local/run_performance.sh b/scripts/deploy/performance_local/run_performance.sh index 25cab4cf4c..558c2282c1 100755 --- a/scripts/deploy/performance_local/run_performance.sh +++ b/scripts/deploy/performance_local/run_performance.sh @@ -26,7 +26,7 @@ home_path="./" server_name=`echo "$server" | awk -F':' '{print $NF}'` server_bin=${server_name} -bazel run //benchmark/protocols/pbft:kv_service_tools -- $PWD/config_out/client.config +bazel run $service_tools -- $PWD/config_out/client.config sleep 60 @@ -55,6 +55,6 @@ done python3 performance/calculate_result.py `ls result_*_log` > results.log -rm -rf result_*_log +#rm -rf result_*_log echo "save result to results.log" cat results.log diff --git a/scripts/deploy/script/deploy_local.sh b/scripts/deploy/script/deploy_local.sh index 11145de839..778ce744a5 100755 --- a/scripts/deploy/script/deploy_local.sh +++ b/scripts/deploy/script/deploy_local.sh @@ -72,6 +72,8 @@ deploy/script/generate_config.sh ${BAZEL_WORKSPACE_PATH} ${output_key_path} ${ou # build kv server bazel build ${server} +# JIM opts for debug +#bazel build -c opt --copt=-g --strip=never ${server} if [ $? 
!= 0 ] then @@ -144,6 +146,7 @@ do private_key="cert/node_"${idx}".key.pri" cert="cert/cert_"${idx}".cert" cd ${home_path}/${main_folder}/$idx; nohup ./${server_bin} server.config ${private_key} ${cert} ${grafna_port} > ${server_bin}.log 2>&1 & + echo "cd ${home_path}/${main_folder}/$idx; nohup ./${server_bin} server.config ${private_key} ${cert} ${grafna_port} > ${server_bin}.log 2>&1 &" ((count++)) ((idx++)) ((grafna_port++)) diff --git a/scripts/deploy/script/generate_config.sh b/scripts/deploy/script/generate_config.sh index aa3d77d71a..4fce410e0d 100755 --- a/scripts/deploy/script/generate_config.sh +++ b/scripts/deploy/script/generate_config.sh @@ -78,5 +78,27 @@ do idx=$(($idx+1)) done +#python3 ${CONFIG_TOOLS_BIN} ./client.config ./client.config.json ${TEMPLATE_PATH} +#mv client.config.json client.config + +# Rewrite client.config into RegionInfo JSON for ReadConfig() +python3 - <<'PY' +import json + +path = "client.config" +replicas = [] +with open(path) as f: + for line in f: + line = line.strip() + if not line: + continue + i, ip, port = line.split() + replicas.append({"id": int(i), "ip": ip, "port": int(port)}) + +with open(path, "w") as out: + json.dump({"replicaInfo": replicas}, out) + out.write("\n") +PY + python3 ${CONFIG_TOOLS_BIN} ./server.config ./server.config.json ${TEMPLATE_PATH} mv server.config.json server.config