diff --git a/source/extensions/filters/http/rate_limit_quota/BUILD b/source/extensions/filters/http/rate_limit_quota/BUILD index 0437fb46f8237..a1a865c9eb01f 100644 --- a/source/extensions/filters/http/rate_limit_quota/BUILD +++ b/source/extensions/filters/http/rate_limit_quota/BUILD @@ -55,6 +55,7 @@ envoy_cc_library( srcs = ["client_impl.cc"], hdrs = ["client_impl.h"], deps = [ + ":filter_persistence", ":global_client_lib", ":quota_bucket_cache", "//envoy/grpc:async_client_interface", diff --git a/source/extensions/filters/http/rate_limit_quota/client_impl.cc b/source/extensions/filters/http/rate_limit_quota/client_impl.cc index 8731678923b93..b62ec5c6a957a 100644 --- a/source/extensions/filters/http/rate_limit_quota/client_impl.cc +++ b/source/extensions/filters/http/rate_limit_quota/client_impl.cc @@ -2,6 +2,7 @@ #include #include +#include #include "envoy/type/v3/ratelimit_strategy.pb.h" #include "envoy/type/v3/token_bucket.pb.h" @@ -23,8 +24,9 @@ void LocalRateLimitClientImpl::createBucket( std::chrono::milliseconds fallback_ttl, bool initial_request_allowed) { // Intentionally crash if the local client is initialized with a null global // client or TLS slot due to a bug. - global_client_->createBucket(bucket_id, id, default_bucket_action, std::move(fallback_action), - fallback_ttl, initial_request_allowed); + tls_store_->global_client->createBucket(bucket_id, id, default_bucket_action, + std::move(fallback_action), fallback_ttl, + initial_request_allowed); } std::shared_ptr LocalRateLimitClientImpl::getBucket(size_t id) { diff --git a/source/extensions/filters/http/rate_limit_quota/client_impl.h b/source/extensions/filters/http/rate_limit_quota/client_impl.h index 70fc2df48849b..60d022a496ba6 100644 --- a/source/extensions/filters/http/rate_limit_quota/client_impl.h +++ b/source/extensions/filters/http/rate_limit_quota/client_impl.h @@ -7,6 +7,7 @@ #include "source/common/grpc/typed_async_client.h" #include "source/extensions/filters/http/common/factory_base.h" +#include "source/extensions/filters/http/rate_limit_quota/filter_persistence.h" #include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h" #include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h" @@ -30,10 +31,8 @@ using GrpcAsyncClient = class LocalRateLimitClientImpl : public RateLimitClient, public Logger::Loggable { public: - explicit LocalRateLimitClientImpl( - GlobalRateLimitClientImpl* global_client, - Envoy::ThreadLocal::TypedSlot& buckets_cache_tls) - : global_client_(global_client), buckets_cache_tls_(buckets_cache_tls) {} + explicit LocalRateLimitClientImpl(std::shared_ptr tls_store) + : tls_store_(std::move(tls_store)) {} void createBucket(const BucketId& bucket_id, size_t id, const BucketAction& default_bucket_action, std::unique_ptr fallback_action, @@ -45,19 +44,18 @@ class LocalRateLimitClientImpl : public RateLimitClient, private: inline std::shared_ptr getBucketsCache() { - return (buckets_cache_tls_.get().has_value()) ? buckets_cache_tls_.get()->quota_buckets_ - : nullptr; + return (tls_store_->buckets_tls.get().has_value()) + ? tls_store_->buckets_tls.get()->quota_buckets_ + : nullptr; } // Lockless access to global resources via TLS. - GlobalRateLimitClientImpl* global_client_; - ThreadLocal::TypedSlot& buckets_cache_tls_; + std::shared_ptr tls_store_; }; inline std::unique_ptr -createLocalRateLimitClient(GlobalRateLimitClientImpl* global_client, - ThreadLocal::TypedSlot& buckets_cache_tls_) { - return std::make_unique(global_client, buckets_cache_tls_); +createLocalRateLimitClient(std::shared_ptr tls_store) { + return std::make_unique(std::move(tls_store)); } } // namespace RateLimitQuota diff --git a/source/extensions/filters/http/rate_limit_quota/config.cc b/source/extensions/filters/http/rate_limit_quota/config.cc index f40442d3ef357..119121b6fdc6a 100644 --- a/source/extensions/filters/http/rate_limit_quota/config.cc +++ b/source/extensions/filters/http/rate_limit_quota/config.cc @@ -64,11 +64,10 @@ Http::FilterFactoryCb RateLimitQuotaFilterFactory::createFilterFactoryFromProtoT return [&, config = std::move(config), config_with_hash_key, tls_store = std::move(tls_store), matcher = std::move(matcher)](Http::FilterChainFactoryCallbacks& callbacks) -> void { - std::unique_ptr local_client = - createLocalRateLimitClient(tls_store->global_client.get(), tls_store->buckets_tls); + std::unique_ptr local_client = createLocalRateLimitClient(tls_store); callbacks.addStreamFilter(std::make_shared( - config, context, std::move(local_client), config_with_hash_key, matcher)); + config, context, tls_store, std::move(local_client), config_with_hash_key, matcher)); }; } diff --git a/source/extensions/filters/http/rate_limit_quota/filter.h b/source/extensions/filters/http/rate_limit_quota/filter.h index 547a00f531bf5..0221a97e4b38b 100644 --- a/source/extensions/filters/http/rate_limit_quota/filter.h +++ b/source/extensions/filters/http/rate_limit_quota/filter.h @@ -14,6 +14,7 @@ #include "source/common/protobuf/utility.h" #include "source/extensions/filters/http/common/factory_base.h" #include "source/extensions/filters/http/common/pass_through_filter.h" +#include "source/extensions/filters/http/rate_limit_quota/filter_persistence.h" #include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h" #include "source/extensions/filters/http/rate_limit_quota/matcher.h" #include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h" @@ -50,11 +51,13 @@ class RateLimitQuotaFilter : public Http::PassThroughFilter, public: RateLimitQuotaFilter(FilterConfigConstSharedPtr config, Server::Configuration::FactoryContext& factory_context, + std::shared_ptr tls_store, std::unique_ptr local_client, Grpc::GrpcServiceConfigWithHashKey config_with_hash_key, Matcher::MatchTreeSharedPtr matcher) : config_(std::move(config)), config_with_hash_key_(config_with_hash_key), - factory_context_(factory_context), matcher_(matcher), client_(std::move(local_client)), + factory_context_(factory_context), matcher_(matcher), tls_store_(std::move(tls_store)), + client_(std::move(local_client)), time_source_(factory_context.serverFactoryContext().mainThreadDispatcher().timeSource()) {} Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap&, bool end_stream) override; @@ -95,6 +98,9 @@ class RateLimitQuotaFilter : public Http::PassThroughFilter, // shouldn't be recorded. bool first_skipped_match_ = true; + // Anchors the lifetime of the global client and its resources for the + // duration of the filter. + std::shared_ptr tls_store_; // Own a local, filter-specific client to provider functions needed by worker // threads. std::unique_ptr client_; diff --git a/source/extensions/filters/http/rate_limit_quota/filter_persistence.cc b/source/extensions/filters/http/rate_limit_quota/filter_persistence.cc index f5920d6367b57..c52a6b8c5dee5 100644 --- a/source/extensions/filters/http/rate_limit_quota/filter_persistence.cc +++ b/source/extensions/filters/http/rate_limit_quota/filter_persistence.cc @@ -8,6 +8,7 @@ #include "envoy/grpc/async_client_manager.h" #include "envoy/server/factory_context.h" +#include "source/common/common/logger.h" #include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h" #include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h" @@ -31,7 +32,16 @@ initTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key, // Quota bucket & global client TLS objects are created with the config and // kept alive via shared_ptr to a storage struct. The local rate limit client // in each filter instance assumes that the slot will outlive them. - std::shared_ptr tls_store = std::make_shared(context, target_address, domain); + Envoy::Event::Dispatcher& dispatcher = context.serverFactoryContext().mainThreadDispatcher(); + auto deleter = [main_dispatcher = &dispatcher](TlsStore* store) { + if (main_dispatcher->isThreadSafe()) { + delete store; + } else { + main_dispatcher->post([store]() { delete store; }); + } + }; + + std::shared_ptr tls_store(new TlsStore(context, target_address, domain), deleter); auto tl_buckets_cache = std::make_shared(std::make_shared()); tls_store->buckets_tls.set( @@ -46,7 +56,7 @@ initTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key, // Create the global client resource to be shared via TLS to all worker // threads (accessed through a filter-specific LocalRateLimitClient). - std::unique_ptr tl_global_client = createGlobalRateLimitClientImpl( + std::shared_ptr tl_global_client = createGlobalRateLimitClientImpl( context, domain, reporting_interval, tls_store->buckets_tls, config_with_hash_key); tls_store->global_client = std::move(tl_global_client); @@ -60,11 +70,8 @@ GlobalTlsStores::getTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_wi Server::Configuration::FactoryContext& context, absl::string_view target_address, absl::string_view domain) { TlsStoreIndex index = std::make_pair(std::string(target_address), std::string(domain)); - // Find existing TlsStore or initialize a new one. auto it = stores().find(index); if (it != stores().end()) { - ENVOY_LOG(debug, "Found existing cache & RLQS client for target ({}) and domain ({}).", - index.first, index.second); return it->second.lock(); } ENVOY_LOG(debug, "Creating a new cache & RLQS client for target ({}) and domain ({}).", @@ -76,6 +83,14 @@ GlobalTlsStores::getTlsStore(const Grpc::GrpcServiceConfigWithHashKey& config_wi return tls_store; } +void GlobalTlsStores::clearTlsStore(const std::pair& index) { + stores().erase(index); + if (stores().empty() && getEmptiedCb() != nullptr) { + getEmptiedCb()(); + getEmptiedCb() = nullptr; + } +} + } // namespace RateLimitQuota } // namespace HttpFilters } // namespace Extensions diff --git a/source/extensions/filters/http/rate_limit_quota/filter_persistence.h b/source/extensions/filters/http/rate_limit_quota/filter_persistence.h index dcd7b1a45f642..6c396265c0566 100644 --- a/source/extensions/filters/http/rate_limit_quota/filter_persistence.h +++ b/source/extensions/filters/http/rate_limit_quota/filter_persistence.h @@ -27,11 +27,6 @@ namespace RateLimitQuota { // GlobalTlsStores holds a singleton hash map of rate_limit_quota TLS stores, // indexed by their combined RLQS server targets & domains. // -// This follows the data sharing model of FactoryRegistry, and similarly does -// not guarantee thread-safety. Additions or removals of indices can only be -// done on the main thread, as part of filter factory creation and garbage -// collection respectively. -// // Note, multiple RLQS clients with different configs (e.g. timeouts) can hit // the same index (destination + domain). The global map does not guarantee // which config will be selected for the client creation. @@ -52,12 +47,23 @@ class GlobalTlsStores : public Logger::Loggable { // The global client must be cleaned up by the server main thread before // it shuts down. if (global_client != nullptr) { - main_dispatcher_.deferredDelete(std::move(global_client)); + // SharedClientDeleter wraps a shared_ptr to GlobalRateLimitClientImpl + // to allow it to be used with deferredDelete. It ensures + // deleteIsPending() is called and the shared_ptr is dropped on the main + // thread. + struct SharedClientDeleter : public Event::DeferredDeletable { + SharedClientDeleter(std::shared_ptr client) + : client_(std::move(client)) {} + void deleteIsPending() override { client_->deleteIsPending(); } + std::shared_ptr client_; + }; + main_dispatcher_.deferredDelete( + std::make_unique(std::move(global_client))); } GlobalTlsStores::clearTlsStore(std::make_pair(target_address_, domain_)); } - std::unique_ptr global_client = nullptr; + std::shared_ptr global_client = nullptr; ThreadLocal::TypedSlot buckets_tls; private: @@ -105,13 +111,7 @@ class GlobalTlsStores : public Logger::Loggable { } // Clear a specified index when it is no longer captured by any filter factories. - static void clearTlsStore(const TlsStoreIndex& index) { - stores().erase(index); - if (stores().empty() && getEmptiedCb() != nullptr) { - getEmptiedCb()(); - getEmptiedCb() = nullptr; - } - } + static void clearTlsStore(const TlsStoreIndex& index); }; } // namespace RateLimitQuota diff --git a/source/extensions/filters/http/rate_limit_quota/global_client_impl.cc b/source/extensions/filters/http/rate_limit_quota/global_client_impl.cc index 99e3efd959c28..64fdf837f9dd4 100644 --- a/source/extensions/filters/http/rate_limit_quota/global_client_impl.cc +++ b/source/extensions/filters/http/rate_limit_quota/global_client_impl.cc @@ -78,6 +78,7 @@ void GlobalRateLimitClientImpl::deleteIsPending() { async_client_->reset(); } +// Atomically read usage counters & reset them to 0. void getUsageFromBucket(const CachedBucket& cached_bucket, TimeSource& time_source, BucketQuotaUsage& usage) { std::shared_ptr cached_usage = cached_bucket.quota_usage; @@ -152,12 +153,16 @@ void GlobalRateLimitClientImpl::createBucket(const BucketId& bucket_id, size_t i std::chrono::milliseconds fallback_ttl, bool initial_request_allowed) { // Mutable to move fallback_action ownership into the main thread then into - // the created bucket. - main_dispatcher_.post([&, bucket_id, id, default_bucket_action, + // the created bucket. Captures a weak_ptr to 'this' as the posted operation + // can outlive the global client itself. + main_dispatcher_.post([weak_this = weak_from_this(), bucket_id, id, default_bucket_action, fallback_action_ptr = std::move(fallback_action), fallback_ttl, initial_request_allowed]() mutable { - createBucketImpl(bucket_id, id, default_bucket_action, std::move(fallback_action_ptr), - fallback_ttl, initial_request_allowed); + if (auto shared_this = weak_this.lock()) { + shared_this->createBucketImpl(bucket_id, id, default_bucket_action, + std::move(fallback_action_ptr), fallback_ttl, + initial_request_allowed); + } }); } @@ -274,8 +279,13 @@ void GlobalRateLimitClientImpl::onReceiveMessage(RateLimitQuotaResponsePtr&& res if (response == nullptr) { return; } - main_dispatcher_.post( - [&, response = std::move(response)]() { onQuotaResponseImpl(response.get()); }); + // Captures a weak_ptr to 'this' as the posted operation can outlive the + // global client itself. + main_dispatcher_.post([weak_this = weak_from_this(), response = std::move(response)]() { + if (auto shared_this = weak_this.lock()) { + shared_this->onQuotaResponseImpl(response.get()); + } + }); } // Updating a cached_bucket shouldn't reset the cached token bucket if the @@ -432,12 +442,17 @@ void GlobalRateLimitClientImpl::onSendReportsTimer() { void GlobalRateLimitClientImpl::startActionExpirationTimer(CachedBucket* cached_bucket, size_t id) { // Pointer safety as all writes are against the source-of-truth. - cached_bucket->action_expiration_timer = main_dispatcher_.createTimer([&, id, cached_bucket]() { - onActionExpirationTimer(cached_bucket, id); - if (callbacks_ != nullptr) { - callbacks_->onActionExpiration(); - } - }); + cached_bucket->action_expiration_timer = + main_dispatcher_.createTimer([weak_this = weak_from_this(), id, cached_bucket]() { + auto shared_this = weak_this.lock(); + if (!shared_this) { + return; + } + shared_this->onActionExpirationTimer(cached_bucket, id); + if (shared_this->callbacks_ != nullptr) { + shared_this->callbacks_->onActionExpiration(); + } + }); std::chrono::milliseconds ttl = std::chrono::duration_cast( std::chrono::seconds(cached_bucket->cached_action->quota_assignment_action() .assignment_time_to_live() @@ -521,12 +536,17 @@ void GlobalRateLimitClientImpl::onActionExpirationTimer(CachedBucket* bucket, si void GlobalRateLimitClientImpl::startFallbackExpirationTimer(CachedBucket* cached_bucket, size_t id) { // Pointer safety as all writes are against the source-of-truth. - cached_bucket->fallback_expiration_timer = main_dispatcher_.createTimer([&, id, cached_bucket]() { - onFallbackExpirationTimer(cached_bucket, id); - if (callbacks_ != nullptr) { - callbacks_->onFallbackExpiration(); - } - }); + cached_bucket->fallback_expiration_timer = + main_dispatcher_.createTimer([weak_this = weak_from_this(), id, cached_bucket]() { + auto shared_this = weak_this.lock(); + if (!shared_this) { + return; + } + shared_this->onFallbackExpirationTimer(cached_bucket, id); + if (shared_this->callbacks_ != nullptr) { + shared_this->callbacks_->onFallbackExpiration(); + } + }); cached_bucket->fallback_expiration_timer->enableTimer(cached_bucket->fallback_ttl); } diff --git a/source/extensions/filters/http/rate_limit_quota/global_client_impl.h b/source/extensions/filters/http/rate_limit_quota/global_client_impl.h index 7259973a529da..16d11ec52234e 100644 --- a/source/extensions/filters/http/rate_limit_quota/global_client_impl.h +++ b/source/extensions/filters/http/rate_limit_quota/global_client_impl.h @@ -58,6 +58,9 @@ class GlobalRateLimitClientCallbacks { class GlobalRateLimitClientImpl : public Grpc::AsyncStreamCallbacks< envoy::service::rate_limit_quota::v3::RateLimitQuotaResponse>, public Event::DeferredDeletable, + // Required to safely capture 'this' in asynchronous callbacks + // posted to the main dispatcher, as they may outlive the client. + public std::enable_shared_from_this, public Logger::Loggable { public: // Note: rlqs_client is owned directly to ensure that it does not outlive the @@ -188,14 +191,14 @@ class GlobalRateLimitClientImpl : public Grpc::AsyncStreamCallbacks< * Create a shared rate limit client. It should be shared to each worker * thread via TLS. */ -inline std::unique_ptr +inline std::shared_ptr createGlobalRateLimitClientImpl(Server::Configuration::FactoryContext& context, absl::string_view domain_name, std::chrono::milliseconds send_reports_interval, ThreadLocal::TypedSlot& buckets_tls, const Grpc::GrpcServiceConfigWithHashKey& config_with_hash_key) { Envoy::Event::Dispatcher& main_dispatcher = context.serverFactoryContext().mainThreadDispatcher(); - return std::make_unique(config_with_hash_key, context, domain_name, + return std::make_shared(config_with_hash_key, context, domain_name, send_reports_interval, buckets_tls, main_dispatcher); } diff --git a/test/extensions/filters/http/rate_limit_quota/client_test.cc b/test/extensions/filters/http/rate_limit_quota/client_test.cc index 0d3d7031c496e..2cf2ed19f53c1 100644 --- a/test/extensions/filters/http/rate_limit_quota/client_test.cc +++ b/test/extensions/filters/http/rate_limit_quota/client_test.cc @@ -14,15 +14,14 @@ #include "envoy/type/v3/ratelimit_strategy.pb.h" #include "envoy/type/v3/token_bucket.pb.h" -#include "source/common/protobuf/protobuf.h" #include "source/common/protobuf/utility.h" #include "source/extensions/filters/http/rate_limit_quota/client_impl.h" +#include "source/extensions/filters/http/rate_limit_quota/filter_persistence.h" #include "source/extensions/filters/http/rate_limit_quota/global_client_impl.h" #include "source/extensions/filters/http/rate_limit_quota/quota_bucket_cache.h" #include "test/extensions/filters/http/rate_limit_quota/client_test_utils.h" #include "test/mocks/grpc/mocks.h" -#include "test/mocks/upstream/cluster_info.h" #include "test/test_common/logging.h" #include "absl/container/flat_hash_map.h" @@ -102,16 +101,21 @@ class GlobalClientTest : public ::testing::Test { void SetUp() override { mock_stream_client = std::make_unique(); - buckets_tls_ = std::make_unique>( - mock_stream_client->context_.server_factory_context_.thread_local_); + + mock_stream_client->expectClientCreationWithFactory(); + tls_store_ = std::make_shared(mock_stream_client->context_, + "mock_target", mock_domain_); + + // Use the TlsStore's buckets_tls. auto initial_tl_buckets_cache = std::make_shared(std::make_shared()); - buckets_tls_->set([initial_tl_buckets_cache](Unused) { return initial_tl_buckets_cache; }); + tls_store_->buckets_tls.set( + [initial_tl_buckets_cache](Unused) { return initial_tl_buckets_cache; }); - mock_stream_client->expectClientCreationWithFactory(); - global_client_ = std::make_unique( + global_client_ = std::make_shared( mock_stream_client->config_with_hash_key_, mock_stream_client->context_, mock_domain_, - reporting_interval_, *buckets_tls_, *mock_stream_client->dispatcher_); + reporting_interval_, tls_store_->buckets_tls, *mock_stream_client->dispatcher_); + tls_store_->global_client = global_client_; // Set callbacks to handle asynchronous timing. auto callbacks = std::make_unique(); cb_ptr_ = callbacks.get(); @@ -121,12 +125,22 @@ class GlobalClientTest : public ::testing::Test { } void TearDown() override { - // Normally called by TlsStore destructor as part of filter factory cb deletion. - mock_stream_client->dispatcher_->deferredDelete(std::move(global_client_)); + if (global_client_ != nullptr) { + struct SharedClientDeleter : public Event::DeferredDeletable { + SharedClientDeleter(std::shared_ptr client) + : client_(std::move(client)) {} + void deleteIsPending() override { client_->deleteIsPending(); } + std::shared_ptr client_; + }; + mock_stream_client->dispatcher_->deferredDelete( + std::make_unique(std::move(global_client_))); + } + tls_store_ = nullptr; } std::unique_ptr mock_stream_client = nullptr; - std::unique_ptr global_client_ = nullptr; + std::shared_ptr global_client_ = nullptr; + std::shared_ptr tls_store_ = nullptr; ThreadLocal::TypedSlotPtr buckets_tls_ = nullptr; GlobalClientCallbacks* cb_ptr_ = nullptr; @@ -312,7 +326,7 @@ TEST_F(GlobalClientTest, TestInitialCreation) { std::chrono::milliseconds::zero(), true); // Expect the bucket cache to update with a new bucket quickly. cb_ptr_->waitForExpectedBuckets(); - auto cache_ref = buckets_tls_->get(); + auto cache_ref = tls_store_->buckets_tls.get(); ASSERT_TRUE(cache_ref.has_value()); ASSERT_TRUE(cache_ref->quota_buckets_); ASSERT_EQ(cache_ref->quota_buckets_->size(), 1); @@ -370,7 +384,7 @@ TEST_F(GlobalClientTest, TestCreationWithDefaultDeny) { std::chrono::milliseconds::zero(), false); // Expect the bucket cache to update with a new bucket quickly. cb_ptr_->waitForExpectedBuckets(); - auto cache_ref = buckets_tls_->get(); + auto cache_ref = tls_store_->buckets_tls.get(); ASSERT_TRUE(cache_ref.has_value()); ASSERT_TRUE(cache_ref->quota_buckets_); ASSERT_EQ(cache_ref->quota_buckets_->size(), 1); @@ -410,7 +424,7 @@ TEST_F(GlobalClientTest, BasicUsageReporting) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); // Get bucket from TLS. - std::shared_ptr quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + std::shared_ptr quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); setAtomic(1, quota_usage->num_requests_allowed); setAtomic(2, quota_usage->num_requests_denied); @@ -426,7 +440,7 @@ TEST_F(GlobalClientTest, BasicUsageReporting) { waitForNotification(cb_ptr_->report_sent); // After the expected report goes out, the atomics should be reset for the // next aggregation interval. - quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); EXPECT_EQ(quota_usage->num_requests_allowed.load(std::memory_order_relaxed), 0); EXPECT_EQ(quota_usage->num_requests_denied.load(std::memory_order_relaxed), 0); } @@ -455,7 +469,7 @@ TEST_F(GlobalClientTest, TestStreamCreationFailures) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); // Bucket should be created, even with the stream failure. - std::shared_ptr quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + std::shared_ptr quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); EXPECT_GT(quota_usage->num_requests_allowed, 0); EXPECT_LT(quota_usage->num_requests_allowed, 4); // With the timer cb, the stream should be reattempted, fail starting and @@ -465,7 +479,7 @@ TEST_F(GlobalClientTest, TestStreamCreationFailures) { waitForNotification(cb_ptr_->report_sent); // Refresh state from the buckets cache in TLS. Expect the atomics to have // reset after the dropped reports. - quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); EXPECT_EQ(quota_usage->num_requests_allowed, 0); setAtomic(4 + i, quota_usage->num_requests_allowed); } @@ -478,7 +492,7 @@ TEST_F(GlobalClientTest, TestStreamCreationFailures) { mock_stream_client->timer_->invokeCallback(); waitForNotification(cb_ptr_->report_sent); - quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); EXPECT_EQ(quota_usage->num_requests_allowed.load(std::memory_order_relaxed), 0); } @@ -501,7 +515,7 @@ TEST_F(GlobalClientTest, TestStreamFailureMidUse) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); // Get bucket from TLS. - std::shared_ptr quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + std::shared_ptr quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); setAtomic(1, quota_usage->num_requests_allowed); setAtomic(2, quota_usage->num_requests_denied); @@ -515,7 +529,7 @@ TEST_F(GlobalClientTest, TestStreamFailureMidUse) { // After the expected report goes out, the atomics should be reset for the // next aggregation interval. - quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); + quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); EXPECT_EQ(quota_usage->num_requests_allowed.load(std::memory_order_relaxed), 0); EXPECT_EQ(quota_usage->num_requests_denied.load(std::memory_order_relaxed), 0); // Close the stream to show the internal restart mechanism. @@ -541,8 +555,9 @@ TEST_F(GlobalClientTest, TestStreamFailureMidUse) { // Wait for the second bucket creation to complete. cb_ptr_->waitForExpectedBuckets(); - quota_usage = getQuotaUsage(*buckets_tls_, sample_id_hash_); - std::shared_ptr quota_usage2 = getQuotaUsage(*buckets_tls_, sample_id_hash2); + quota_usage = getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_); + std::shared_ptr quota_usage2 = + getQuotaUsage(tls_store_->buckets_tls, sample_id_hash2); EXPECT_EQ(quota_usage->num_requests_allowed.load(std::memory_order_relaxed), 3); EXPECT_EQ(quota_usage->num_requests_denied.load(std::memory_order_relaxed), 4); EXPECT_EQ(quota_usage2->num_requests_allowed.load(std::memory_order_relaxed), 1); @@ -599,9 +614,9 @@ TEST_F(GlobalClientTest, TestBasicResponseProcessing) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); - setAtomic(2, getQuotaUsage(*buckets_tls_, sample_id_hash2)->num_requests_allowed); - setAtomic(3, getQuotaUsage(*buckets_tls_, sample_id_hash3)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); + setAtomic(2, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash2)->num_requests_allowed); + setAtomic(3, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash3)->num_requests_allowed); RateLimitQuotaUsageReports expected_reports = buildReports( std::vector{{/*allowed=*/1, /*denied=*/0, /*bucket_id=*/sample_bucket_id_}, @@ -630,15 +645,17 @@ TEST_F(GlobalClientTest, TestBasicResponseProcessing) { waitForNotification(cb_ptr_->response_processed); // Expect the buckets in TLS to have matching assignments. - std::shared_ptr deny_all_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr deny_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(deny_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*deny_all_bucket->cached_action, deny_action)); - std::shared_ptr allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash2); + std::shared_ptr allow_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash2); ASSERT_TRUE(allow_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*allow_all_bucket->cached_action, allow_action)); - std::shared_ptr token_bucket = getBucket(*buckets_tls_, sample_id_hash3); + std::shared_ptr token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash3); ASSERT_TRUE(token_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->cached_action, token_bucket_action)); @@ -690,7 +707,7 @@ TEST_F(GlobalClientTest, TestDuplicateTokenBucket) { /*initial_request_allowed=*/true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); mock_stream_client->timer_->invokeCallback(); waitForNotification(cb_ptr_->report_sent); @@ -706,7 +723,7 @@ TEST_F(GlobalClientTest, TestDuplicateTokenBucket) { waitForNotification(cb_ptr_->response_processed); // Verify the integrity of the token bucket configuration. - std::shared_ptr token_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(token_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->cached_action, token_bucket_action)); @@ -726,7 +743,7 @@ TEST_F(GlobalClientTest, TestDuplicateTokenBucket) { waitForNotification(cb_ptr_->response_processed); // Get the updated token bucket out of the cache. - token_bucket = getBucket(*buckets_tls_, sample_id_hash_); + token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(token_bucket->cached_action); // Confirm that the action is still the same. EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->cached_action, token_bucket_action)); @@ -748,7 +765,7 @@ TEST_F(GlobalClientTest, TestDuplicateTokenBucket) { waitForNotification(cb_ptr_->action_expired); // Confirm the final token bucket state only has the default action. - token_bucket = getBucket(*buckets_tls_, sample_id_hash_); + token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash_); EXPECT_FALSE(token_bucket->cached_action); } @@ -778,8 +795,8 @@ TEST_F(GlobalClientTest, TestResponseProcessingForNonExistentBucket) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); - EXPECT_OK(tryGetBucket(*buckets_tls_, sample_id_hash_)); - EXPECT_FALSE(tryGetBucket(*buckets_tls_, sample_id_hash2).ok()); + EXPECT_OK(tryGetBucket(tls_store_->buckets_tls, sample_id_hash_)); + EXPECT_FALSE(tryGetBucket(tls_store_->buckets_tls, sample_id_hash2).ok()); auto deny_action = buildBlanketAction(sample_bucket_id_, true); auto allow_action = buildBlanketAction(sample_bucket_id2, false); @@ -793,12 +810,13 @@ TEST_F(GlobalClientTest, TestResponseProcessingForNonExistentBucket) { // Expect the second bucket hash to not be in the bucket cache as it wasn't // there before the response included it. - std::shared_ptr deny_all_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr deny_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(deny_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*deny_all_bucket->cached_action, deny_action)); absl::StatusOr> allow_all_bucket = - tryGetBucket(*buckets_tls_, sample_id_hash2); + tryGetBucket(tls_store_->buckets_tls, sample_id_hash2); EXPECT_FALSE(allow_all_bucket.ok()); EXPECT_EQ(mock_stream_client->expiration_timers_.size(), 1); @@ -843,8 +861,8 @@ TEST_F(GlobalClientTest, TestResponseEdgeCases) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash2)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash2)->num_requests_allowed); RateLimitQuotaUsageReports expected_reports = buildReports( std::vector{{/*allowed=*/1, /*denied=*/0, /*bucket_id=*/sample_bucket_id_}, @@ -889,11 +907,11 @@ TEST_F(GlobalClientTest, TestResponseEdgeCases) { }); // Expect the buckets in TLS to still only have the default action. - std::shared_ptr bucket1 = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr bucket1 = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_FALSE(bucket1->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(bucket1->default_action, default_allow_action)); - std::shared_ptr bucket2 = getBucket(*buckets_tls_, sample_id_hash2); + std::shared_ptr bucket2 = getBucket(tls_store_->buckets_tls, sample_id_hash2); ASSERT_FALSE(bucket2->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(bucket2->default_action, default_allow_action2)); @@ -964,9 +982,9 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { mock_stream_client->stream_, sendMessageRaw_(Grpc::ProtoBufferEqIgnoreRepeatedFieldOrdering(expected_reports), false)); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash2)->num_requests_allowed); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash3)->num_requests_denied); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash2)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash3)->num_requests_denied); mock_stream_client->timer_->invokeCallback(); waitForNotification(cb_ptr_->report_sent); @@ -990,17 +1008,19 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->response_processed); // Expect the buckets in TLS to have matching assignments. - std::shared_ptr deny_all_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr deny_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(deny_all_bucket && deny_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*deny_all_bucket->cached_action, deny_action)); - std::shared_ptr token_bucket = getBucket(*buckets_tls_, sample_id_hash2); + std::shared_ptr token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash2); ASSERT_TRUE(token_bucket && token_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->cached_action, token_bucket_action)); ASSERT_TRUE(token_bucket->fallback_action); EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->fallback_action, fallback_tb_action)); - std::shared_ptr allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash3); + std::shared_ptr allow_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash3); ASSERT_TRUE(allow_all_bucket && allow_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*allow_all_bucket->cached_action, allow_action)); @@ -1015,8 +1035,8 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->action_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(token_bucket, getBucket(*buckets_tls_, sample_id_hash2)); - token_bucket = getBucket(*buckets_tls_, sample_id_hash2); + EXPECT_NE(token_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash2)); + token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash2); // Expect a fallback timer for the expired bucket while the other two are // unaffected. ASSERT_EQ(mock_stream_client->fallback_timers_.size(), 1); @@ -1039,8 +1059,8 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->fallback_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(token_bucket, getBucket(*buckets_tls_, sample_id_hash2)); - token_bucket = getBucket(*buckets_tls_, sample_id_hash2); + EXPECT_NE(token_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash2)); + token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash2); // Expect the second bucket to have lost its fallback timer & cached action. ASSERT_FALSE(token_bucket->cached_action); ASSERT_FALSE(token_bucket->token_bucket_limiter); @@ -1055,8 +1075,8 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->action_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(deny_all_bucket, getBucket(*buckets_tls_, sample_id_hash_)); - deny_all_bucket = getBucket(*buckets_tls_, sample_id_hash_); + EXPECT_NE(deny_all_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash_)); + deny_all_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash_); // Don't expect a fallback timer for the first bucket. ASSERT_EQ(mock_stream_client->fallback_timers_.size(), 1); // Expect the first bucket to have lost its cached action. @@ -1071,8 +1091,8 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->action_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(allow_all_bucket, getBucket(*buckets_tls_, sample_id_hash3)); - allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash3); + EXPECT_NE(allow_all_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash3)); + allow_all_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash3); // Expect a fallback timer for the third bucket. ASSERT_EQ(mock_stream_client->fallback_timers_.size(), 2); // Expect the third bucket to have replaced its cached action with the @@ -1097,7 +1117,7 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->response_processed); // Re-get the updated third bucket after the TLS push. - allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash3); + allow_all_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash3); // Expect the third bucket to have a new cached action & no fallback timer. ASSERT_EQ(mock_stream_client->expiration_timers_.size(), 4); ASSERT_TRUE(allow_all_bucket && allow_all_bucket->cached_action); @@ -1115,16 +1135,16 @@ TEST_F(GlobalClientTest, TestExpirationAndFallback) { waitForNotification(cb_ptr_->action_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(allow_all_bucket, getBucket(*buckets_tls_, sample_id_hash3)); - allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash3); + EXPECT_NE(allow_all_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash3)); + allow_all_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash3); Event::MockTimer* replacement_allow_all_fallback_timer = RateLimitTestClient::assertMockTimer(allow_all_bucket->fallback_expiration_timer.get()); replacement_allow_all_fallback_timer->invokeCallback(); waitForNotification(cb_ptr_->fallback_expired); // Get the new cached bucket replacing the expired one. - EXPECT_NE(allow_all_bucket, getBucket(*buckets_tls_, sample_id_hash3)); - allow_all_bucket = getBucket(*buckets_tls_, sample_id_hash3); + EXPECT_NE(allow_all_bucket, getBucket(tls_store_->buckets_tls, sample_id_hash3)); + allow_all_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash3); // Expect the third bucket to have lost its fallback timer & cached action. ASSERT_FALSE(allow_all_bucket->cached_action); ASSERT_FALSE(allow_all_bucket->fallback_expiration_timer); @@ -1160,7 +1180,7 @@ TEST_F(GlobalClientTest, TestFallbackToDuplicateTokenBucket) { std::chrono::seconds(300), true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); mock_stream_client->timer_->invokeCallback(); waitForNotification(cb_ptr_->report_sent); @@ -1176,7 +1196,7 @@ TEST_F(GlobalClientTest, TestFallbackToDuplicateTokenBucket) { waitForNotification(cb_ptr_->response_processed); // Expect the bucket in TLS to have a matching assignment. - std::shared_ptr token_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr token_bucket = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(token_bucket && token_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*token_bucket->cached_action, token_bucket_action)); ASSERT_TRUE(token_bucket->fallback_action); @@ -1194,7 +1214,8 @@ TEST_F(GlobalClientTest, TestFallbackToDuplicateTokenBucket) { // Get the new CachedBucket, which should have carried over the existing token // bucket. - std::shared_ptr new_token_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr new_token_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash_); EXPECT_NE(token_bucket.get(), new_token_bucket.get()); // Expect a fallback timer for the expired bucket. @@ -1235,12 +1256,12 @@ TEST_F(GlobalClientTest, TestAbandonAction) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); mock_stream_client->timer_->invokeCallback(); waitForNotification(cb_ptr_->report_sent); // Expect the bucket in TLS. - std::shared_ptr bucket_before = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr bucket_before = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(bucket_before); // Test abandon-action response handling. @@ -1253,7 +1274,7 @@ TEST_F(GlobalClientTest, TestAbandonAction) { waitForNotification(cb_ptr_->response_processed); // Expect the bucket to be wiped. - std::shared_ptr bucket_after = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr bucket_after = getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_FALSE(bucket_after); } @@ -1277,7 +1298,7 @@ TEST_F(GlobalClientTest, TestResponseBucketMissingId) { std::chrono::milliseconds::zero(), true); cb_ptr_->waitForExpectedBuckets(); - setAtomic(1, getQuotaUsage(*buckets_tls_, sample_id_hash_)->num_requests_allowed); + setAtomic(1, getQuotaUsage(tls_store_->buckets_tls, sample_id_hash_)->num_requests_allowed); RateLimitQuotaUsageReports expected_reports = buildReports( std::vector{{/*allowed=*/1, /*denied=*/0, /*bucket_id=*/sample_bucket_id_}}); @@ -1303,7 +1324,8 @@ TEST_F(GlobalClientTest, TestResponseBucketMissingId) { }); // Expect the deny-all bucket to have made it into TLS. - std::shared_ptr deny_all_bucket = getBucket(*buckets_tls_, sample_id_hash_); + std::shared_ptr deny_all_bucket = + getBucket(tls_store_->buckets_tls, sample_id_hash_); ASSERT_TRUE(deny_all_bucket->cached_action); EXPECT_TRUE(unordered_differencer_.Equals(*deny_all_bucket->cached_action, deny_action)); } @@ -1315,7 +1337,7 @@ class LocalClientTest : public GlobalClientTest { void SetUp() override { GlobalClientTest::SetUp(); // Create the local client for testing. - local_client_ = std::make_unique(global_client_.get(), *buckets_tls_); + local_client_ = std::make_unique(tls_store_); } std::unique_ptr local_client_ = nullptr; diff --git a/test/extensions/filters/http/rate_limit_quota/filter_persistence_test.cc b/test/extensions/filters/http/rate_limit_quota/filter_persistence_test.cc index e71f0e199f66e..8366880567e45 100644 --- a/test/extensions/filters/http/rate_limit_quota/filter_persistence_test.cc +++ b/test/extensions/filters/http/rate_limit_quota/filter_persistence_test.cc @@ -13,6 +13,7 @@ #include "test/integration/fake_upstream.h" #include "test/integration/http_integration.h" #include "test/integration/integration_stream_decoder.h" +#include "test/mocks/server/factory_context.h" #include "test/test_common/simulated_time_system.h" #include "test/test_common/utility.h" @@ -124,6 +125,11 @@ class FilterPersistenceTest : public Event::TestUsingSimulatedTime, HttpIntegrationTest::initialize(); } + void resetTlsStoreEmptiedCb() { + tls_store_emptied_ = std::make_unique(); + GlobalTlsStores::registerEmptiedCb([&]() { tls_store_emptied_->Notify(); }); + } + // The RLQS upstream shouldn't be autonomous as it will handle the long-lived // RLQS stream. void createUpstreams() override { @@ -156,8 +162,7 @@ class FilterPersistenceTest : public Event::TestUsingSimulatedTime, rlqs_upstream_refs.rlqs_cluster_->set_name(absl::StrCat("rlqs_upstream_", i)); } }); - tls_store_emptied_ = std::make_unique(); - GlobalTlsStores::registerEmptiedCb([&]() { tls_store_emptied_->Notify(); }); + resetTlsStoreEmptiedCb(); } void updateConfigInPlace(std::function modifier) { @@ -195,12 +200,14 @@ class FilterPersistenceTest : public Event::TestUsingSimulatedTime, hcm_config.clear_http_filters(); hcm_filter->mutable_typed_config()->PackFrom(hcm_config); }); - // Wait for all TLS stores to be deleted now that the filter factories are gone. - ASSERT_TRUE(waitForAllTlsStoreDeletions()); } void cleanUp() { - wipeFilters(); + // Cleanup leftover filters and their TLS stores. + if (tls_store_emptied_ != nullptr && !tls_store_emptied_->HasBeenNotified()) { + wipeFilters(); + ASSERT_TRUE(waitForAllTlsStoreDeletions()); + } for (auto& rlqs_upstream : rlqs_upstreams_) { if (rlqs_upstream.rlqs_connection_ != nullptr) { ASSERT_TRUE(rlqs_upstream.rlqs_connection_->close()); @@ -535,6 +542,90 @@ domain: "test_domain" ASSERT_EQ(sendRequest(&headers), "429"); } +// Verify that the callback registered via registerEmptiedCb fires every time +// the stores map becomes empty. +TEST_P(FilterPersistenceTest, TestEmptiedCallbackFiresMultipleTimes) { + auto remove_rlqs = [&](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + auto* listener = bootstrap.mutable_static_resources()->mutable_listeners(0); + auto* hcm_filter = listener->mutable_filter_chains(0)->mutable_filters(0); + HttpConnectionManager hcm_config; + hcm_filter->mutable_typed_config()->UnpackTo(&hcm_config); + auto* filters = hcm_config.mutable_http_filters(); + for (int i = 0; i < filters->size();) { + if (filters->Get(i).name() == "envoy.filters.http.rate_limit_quota") { + filters->DeleteSubrange(i, 1); + } else { + i++; + } + } + hcm_filter->mutable_typed_config()->PackFrom(hcm_config); + }; + + // 1. Initial removal of the RLQS filter to trigger the first call of the callback. + updateConfigInPlace(remove_rlqs); + ASSERT_TRUE(waitForAllTlsStoreDeletions()); + + // 2. Add the filter back and verify it works. + updateConfigInPlace([&](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + auto* listener = bootstrap.mutable_static_resources()->mutable_listeners(0); + auto* hcm_filter = listener->mutable_filter_chains(0)->mutable_filters(0); + HttpConnectionManager hcm_config; + hcm_filter->mutable_typed_config()->UnpackTo(&hcm_config); + // Prepend the RLQS filter. + (void)hcm_config.add_http_filters(); + for (int i = hcm_config.http_filters_size() - 1; i > 0; --i) { + hcm_config.mutable_http_filters()->SwapElements(i, i - 1); + } + TestUtility::loadFromYaml(kDefaultRateLimitQuotaFilter, *hcm_config.mutable_http_filters(0)); + hcm_filter->mutable_typed_config()->PackFrom(hcm_config); + }); + absl::flat_hash_map headers = {{"environment", "staging"}}; + ASSERT_EQ(sendRequest(&headers), "200"); + // Reset the notification as we now expect another future call to the emptiedCb + resetTlsStoreEmptiedCb(); + + // 3. Remove the filter again. The callback should be called a second time. + updateConfigInPlace(remove_rlqs); + ASSERT_TRUE(waitForAllTlsStoreDeletions()); + // No need to reset the notification as we're leaving the config empty of RLQS filters. +} + +TEST_P(FilterPersistenceTest, TestDeletingTlsStoreFromWorkerThread) { + RateLimitQuotaUsageReports expected_reports; + TestUtility::loadFromYaml(R"EOF( +domain: "test_domain" +bucket_quota_usages: + bucket_id: + bucket: + "test_key_1": + "test_value_1" + "test_key_2": + "test_value_2" + num_requests_allowed: 1 +)EOF", + expected_reports); + // The first request should trigger an immediate usage report. The + // no-assignment behavior is ALLOW_ALL so the first request should be allowed. + absl::flat_hash_map headers = {{"environment", "staging"}}; + ASSERT_EQ(sendRequest(&headers), "200"); + expectRlqsUsageReports(0, expected_reports, true); + + auto ctx = testing::NiceMock(); + // The TlsStore should already exist, so parameters needed for initialization aren't important. + // Capture a shared_ptr to mimic deletion while a filter instance's shared_ptr outliving + // the filter factory's. + std::shared_ptr tls_store = GlobalTlsStores::getTlsStore( + Envoy::Grpc::GrpcServiceConfigWithHashKey(), ctx, + rlqs_upstreams_[0].rlqs_upstream_->localAddress()->asStringView(), "test_domain"); + + // Push a config update to remove all RLQS filters and wait for confirmation that the config + // loaded via stat tracking. + wipeFilters(); + EXPECT_EQ(tls_store.use_count(), 1); + tls_store = nullptr; + EXPECT_TRUE(waitForAllTlsStoreDeletions()); +} + } // namespace } // namespace RateLimitQuota } // namespace HttpFilters diff --git a/test/extensions/filters/http/rate_limit_quota/filter_test.cc b/test/extensions/filters/http/rate_limit_quota/filter_test.cc index 6e32532b236e8..06c9b4b15d25c 100644 --- a/test/extensions/filters/http/rate_limit_quota/filter_test.cc +++ b/test/extensions/filters/http/rate_limit_quota/filter_test.cc @@ -118,8 +118,12 @@ class FilterTest : public testing::Test { Grpc::GrpcServiceConfigWithHashKey config_with_hash_key = Grpc::GrpcServiceConfigWithHashKey(filter_config_->rlqs_server()); + // Initialize the TLS store for the filter to hold. + tls_store_ = + std::make_shared(context_, "mock_target", "mock_domain"); + mock_local_client_ = new MockRateLimitClient(); - filter_ = std::make_unique(filter_config_, context_, + filter_ = std::make_unique(filter_config_, context_, tls_store_, absl::WrapUnique(mock_local_client_), config_with_hash_key, match_tree_); if (set_callback) { @@ -186,6 +190,7 @@ class FilterTest : public testing::Test { NiceMock decoder_callbacks_; MockRateLimitClient* mock_local_client_ = nullptr; + std::shared_ptr tls_store_ = nullptr; FilterConfigConstSharedPtr filter_config_; FilterConfig config_; Matcher::MatchTreeSharedPtr match_tree_ = nullptr; @@ -985,8 +990,9 @@ TEST_F(FilterTest, DenyResponseWithExplicitGrpcStatus) { Grpc::GrpcServiceConfigWithHashKey config_with_hash_key = Grpc::GrpcServiceConfigWithHashKey(filter_config_->rlqs_server()); + tls_store_ = std::make_shared(context_, "mock_target", "mock_domain"); mock_local_client_ = new MockRateLimitClient(); - filter_ = std::make_unique(filter_config_, context_, + filter_ = std::make_unique(filter_config_, context_, tls_store_, absl::WrapUnique(mock_local_client_), config_with_hash_key, match_tree_); filter_->setDecoderFilterCallbacks(decoder_callbacks_); @@ -1160,8 +1166,9 @@ TEST_F(FilterTest, CustomGrpcMessageTest) { Grpc::GrpcServiceConfigWithHashKey config_with_hash_key = Grpc::GrpcServiceConfigWithHashKey(filter_config_->rlqs_server()); + tls_store_ = std::make_shared(context_, "mock_target", "mock_domain"); mock_local_client_ = new MockRateLimitClient(); - filter_ = std::make_unique(filter_config_, context_, + filter_ = std::make_unique(filter_config_, context_, tls_store_, absl::WrapUnique(mock_local_client_), config_with_hash_key, match_tree_); filter_->setDecoderFilterCallbacks(decoder_callbacks_);