2424#include < algorithm>
2525#include < chrono>
2626#include < iterator>
27+ #include < mutex>
2728#include < random>
29+ #include < shared_mutex>
2830#include < sstream>
2931
3032#include " BinaryProtoLookupService.h"
@@ -74,8 +76,6 @@ std::string generateRandomName() {
7476 return randomName;
7577}
7678
77- typedef std::unique_lock<std::mutex> Lock;
78-
7979typedef std::vector<std::string> StringList;
8080
8181namespace {
@@ -157,19 +157,26 @@ ExecutorServiceProviderPtr ClientImpl::getPartitionListenerExecutorProvider() {
157157}
158158
159159LookupServicePtr ClientImpl::getLookup (const std::string& redirectedClusterURI) {
160- Lock lock (mutex_);
160+ std::shared_lock readLock (mutex_);
161161 if (redirectedClusterURI.empty ()) {
162162 return lookupServicePtr_;
163163 }
164164
165- auto it = redirectedClusterLookupServicePtrs_.find (redirectedClusterURI);
166- if (it == redirectedClusterLookupServicePtrs_.end ()) {
167- auto lookup = createLookup (redirectedClusterURI);
168- redirectedClusterLookupServicePtrs_.emplace (redirectedClusterURI, lookup);
169- return lookup;
165+ if (auto it = redirectedClusterLookupServicePtrs_.find (redirectedClusterURI);
166+ it != redirectedClusterLookupServicePtrs_.end ()) {
167+ return it->second ;
170168 }
169+ readLock.unlock ();
171170
172- return it->second ;
171+ std::unique_lock writeLock (mutex_);
172+ // Double check in case another thread acquires the lock and inserts a pair first
173+ if (auto it = redirectedClusterLookupServicePtrs_.find (redirectedClusterURI);
174+ it != redirectedClusterLookupServicePtrs_.end ()) {
175+ return it->second ;
176+ }
177+ auto lookup = createLookup (redirectedClusterURI);
178+ redirectedClusterLookupServicePtrs_.emplace (redirectedClusterURI, lookup);
179+ return lookup;
173180}
174181
175182void ClientImpl::createProducerAsync (const std::string& topic, const ProducerConfiguration& conf,
@@ -179,7 +186,7 @@ void ClientImpl::createProducerAsync(const std::string& topic, const ProducerCon
179186 }
180187 TopicNamePtr topicName;
181188 {
182- Lock lock (mutex_);
189+ std::shared_lock lock (mutex_);
183190 if (state_ != Open) {
184191 lock.unlock ();
185192 callback (ResultAlreadyClosed, Producer ());
@@ -267,7 +274,7 @@ void ClientImpl::createReaderAsync(const std::string& topic, const MessageId& st
267274 const ReaderConfiguration& conf, const ReaderCallback& callback) {
268275 TopicNamePtr topicName;
269276 {
270- Lock lock (mutex_);
277+ std::shared_lock lock (mutex_);
271278 if (state_ != Open) {
272279 lock.unlock ();
273280 callback (ResultAlreadyClosed, Reader ());
@@ -289,7 +296,7 @@ void ClientImpl::createTableViewAsync(const std::string& topic, const TableViewC
289296 const TableViewCallback& callback) {
290297 TopicNamePtr topicName;
291298 {
292- Lock lock (mutex_);
299+ std::shared_lock lock (mutex_);
293300 if (state_ != Open) {
294301 lock.unlock ();
295302 callback (ResultAlreadyClosed, TableView ());
@@ -355,7 +362,7 @@ void ClientImpl::subscribeWithRegexAsync(const std::string& regexPattern, const
355362 const SubscribeCallback& callback) {
356363 TopicNamePtr topicNamePtr = TopicName::get (regexPattern);
357364
358- Lock lock (mutex_);
365+ std::shared_lock lock (mutex_);
359366 if (state_ != Open) {
360367 lock.unlock ();
361368 callback (ResultAlreadyClosed, Consumer ());
@@ -441,7 +448,7 @@ void ClientImpl::subscribeAsync(const std::vector<std::string>& originalTopics,
441448 auto it = std::unique (topics.begin (), topics.end ());
442449 auto newSize = std::distance (topics.begin (), it);
443450 topics.resize (newSize);
444- Lock lock (mutex_);
451+ std::shared_lock lock (mutex_);
445452 if (state_ != Open) {
446453 lock.unlock ();
447454 callback (ResultAlreadyClosed, Consumer ());
@@ -477,7 +484,7 @@ void ClientImpl::subscribeAsync(const std::string& topic, const std::string& sub
477484 const ConsumerConfiguration& conf, const SubscribeCallback& callback) {
478485 TopicNamePtr topicName;
479486 {
480- Lock lock (mutex_);
487+ std::shared_lock lock (mutex_);
481488 if (state_ != Open) {
482489 lock.unlock ();
483490 callback (ResultAlreadyClosed, Consumer ());
@@ -662,7 +669,7 @@ void ClientImpl::handleGetPartitions(Result result, const LookupDataResultPtr& p
662669void ClientImpl::getPartitionsForTopicAsync (const std::string& topic, const GetPartitionsCallback& callback) {
663670 TopicNamePtr topicName;
664671 {
665- Lock lock (mutex_);
672+ std::shared_lock lock (mutex_);
666673 if (state_ != Open) {
667674 lock.unlock ();
668675 callback (ResultAlreadyClosed, StringList ());
@@ -679,7 +686,9 @@ void ClientImpl::getPartitionsForTopicAsync(const std::string& topic, const GetP
679686}
680687
681688void ClientImpl::closeAsync (const CloseCallback& callback) {
689+ std::unique_lock lock (mutex_);
682690 if (state_ != Open) {
691+ lock.unlock ();
683692 if (callback) {
684693 callback (ResultAlreadyClosed);
685694 }
@@ -689,10 +698,12 @@ void ClientImpl::closeAsync(const CloseCallback& callback) {
689698 state_ = Closing;
690699
691700 memoryLimitController_.close ();
692- getLookup () ->close ();
701+ lookupServicePtr_ ->close ();
693702 for (const auto & it : redirectedClusterLookupServicePtrs_) {
694703 it.second ->close ();
695704 }
705+ redirectedClusterLookupServicePtrs_.clear ();
706+ lock.unlock ();
696707
697708 auto producers = producers_.move ();
698709 auto consumers = consumers_.move ();
@@ -741,7 +752,7 @@ void ClientImpl::handleClose(Result result, const SharedInt& numberOfOpenHandler
741752 --(*numberOfOpenHandlers);
742753 }
743754 if (*numberOfOpenHandlers == 0 ) {
744- Lock lock (mutex_);
755+ std::unique_lock lock (mutex_);
745756 if (state_ == Closed) {
746757 LOG_DEBUG (" Client is already shutting down, possible race condition in handleClose" );
747758 return ;
@@ -821,12 +832,12 @@ void ClientImpl::shutdown() {
821832}
822833
823834uint64_t ClientImpl::newProducerId () {
824- Lock lock (mutex_);
835+ std::shared_lock lock (mutex_);
825836 return producerIdGenerator_++;
826837}
827838
828839uint64_t ClientImpl::newConsumerId () {
829- Lock lock (mutex_);
840+ std::shared_lock lock (mutex_);
830841 return consumerIdGenerator_++;
831842}
832843
@@ -870,51 +881,40 @@ std::chrono::nanoseconds ClientImpl::getOperationTimeout(const ClientConfigurati
870881}
871882
872883void ClientImpl::updateServiceInfo (const ServiceInfo& serviceInfo) {
873- LookupServicePtr oldLookupServicePtr;
874- std::unordered_map<std::string, LookupServicePtr> oldRedirectedLookupServicePtrs;
875-
876- {
877- Lock lock (mutex_);
878- if (state_ != Open) {
879- LOG_ERROR (" Client is not open, cannot update connection info" );
880- return ;
881- }
882-
883- if (serviceInfo.authentication .has_value () && *serviceInfo.authentication ) {
884- clientConfiguration_.setAuth (*serviceInfo.authentication );
885- } else {
886- clientConfiguration_.setAuth (AuthFactory::Disabled ());
887- }
888- if (serviceInfo.tlsTrustCertsFilePath .has_value ()) {
889- clientConfiguration_.setTlsTrustCertsFilePath (*serviceInfo.tlsTrustCertsFilePath );
890- } else {
891- clientConfiguration_.setTlsTrustCertsFilePath (" " );
892- }
893- clientConfiguration_.setUseTls (ServiceNameResolver::useTls (ServiceURI (serviceInfo.serviceUrl )));
894- serviceInfo_ = {serviceInfo.serviceUrl , toOptionalAuthentication (clientConfiguration_.getAuthPtr ()),
895- clientConfiguration_.getTlsTrustCertsFilePath ().empty ()
896- ? std::nullopt
897- : std::make_optional (clientConfiguration_.getTlsTrustCertsFilePath ())};
898-
899- oldLookupServicePtr = std::move (lookupServicePtr_);
900- oldRedirectedLookupServicePtrs = std::move (redirectedClusterLookupServicePtrs_);
901-
902- lookupServicePtr_ = createLookup (serviceInfo.serviceUrl );
903- redirectedClusterLookupServicePtrs_.clear ();
884+ std::unique_lock lock (mutex_);
885+ if (state_ != Open) {
886+ LOG_ERROR (" Client is not open, cannot update connection info" );
887+ return ;
904888 }
905889
906- if (oldLookupServicePtr) {
907- oldLookupServicePtr->close ();
890+ if (serviceInfo.authentication .has_value () && *serviceInfo.authentication ) {
891+ clientConfiguration_.setAuth (*serviceInfo.authentication );
892+ } else {
893+ clientConfiguration_.setAuth (AuthFactory::Disabled ());
908894 }
909- for (const auto & it : oldRedirectedLookupServicePtrs) {
910- it.second ->close ();
895+ if (serviceInfo.tlsTrustCertsFilePath .has_value ()) {
896+ clientConfiguration_.setTlsTrustCertsFilePath (*serviceInfo.tlsTrustCertsFilePath );
897+ } else {
898+ clientConfiguration_.setTlsTrustCertsFilePath (" " );
911899 }
900+ clientConfiguration_.setUseTls (ServiceNameResolver::useTls (ServiceURI (serviceInfo.serviceUrl )));
901+ serviceInfo_ = {serviceInfo.serviceUrl , toOptionalAuthentication (clientConfiguration_.getAuthPtr ()),
902+ clientConfiguration_.getTlsTrustCertsFilePath ().empty ()
903+ ? std::nullopt
904+ : std::make_optional (clientConfiguration_.getTlsTrustCertsFilePath ())};
912905
913906 pool_.resetConnections (clientConfiguration_.getAuthPtr (), clientConfiguration_);
907+
908+ lookupServicePtr_->close ();
909+ for (auto && it : redirectedClusterLookupServicePtrs_) {
910+ it.second ->close ();
911+ }
912+ redirectedClusterLookupServicePtrs_.clear ();
913+ lookupServicePtr_ = createLookup (serviceInfo.serviceUrl );
914914}
915915
916- ServiceInfo ClientImpl::getServiceInfo () {
917- Lock lock (mutex_);
916+ ServiceInfo ClientImpl::getServiceInfo () const {
917+ std::shared_lock lock (mutex_);
918918 return serviceInfo_;
919919}
920920
0 commit comments